Simplify _locally_reject_invite

Update `EventCreationHandler.create_event` to accept an auth_events param, and
use it in `_locally_reject_invite` instead of reinventing the wheel.
This commit is contained in:
Richard van der Hoff 2020-10-13 23:14:35 +01:00
parent d9d86c2996
commit a34b17e492
6 changed files with 52 additions and 57 deletions

View file

@ -97,32 +97,37 @@ class EventBuilder:
def is_state(self): def is_state(self):
return self._state_key is not None return self._state_key is not None
async def build(self, prev_event_ids: List[str]) -> EventBase: async def build(
self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]]
) -> EventBase:
"""Transform into a fully signed and hashed event """Transform into a fully signed and hashed event
Args: Args:
prev_event_ids: The event IDs to use as the prev events prev_event_ids: The event IDs to use as the prev events
auth_event_ids: The event IDs to use as the auth events.
Should normally be set to None, which will cause them to be calculated
based on the room state at the prev_events.
Returns: Returns:
The signed and hashed event. The signed and hashed event.
""" """
if auth_event_ids is None:
state_ids = await self._state.get_current_state_ids( state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids self.room_id, prev_event_ids
) )
auth_ids = self._auth.compute_auth_events(self, state_ids) auth_event_ids = self._auth.compute_auth_events(self, state_ids)
format_version = self.room_version.event_format format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1: if format_version == EventFormatVersions.V1:
# The types of auth/prev events changes between event versions. # The types of auth/prev events changes between event versions.
auth_events = await self._store.add_event_hashes( auth_events = await self._store.add_event_hashes(
auth_ids auth_event_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
prev_events = await self._store.add_event_hashes( prev_events = await self._store.add_event_hashes(
prev_event_ids prev_event_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
else: else:
auth_events = auth_ids auth_events = auth_event_ids
prev_events = prev_event_ids prev_events = prev_event_ids
old_depth = await self._store.get_max_depth_of(prev_event_ids) old_depth = await self._store.get_max_depth_of(prev_event_ids)

View file

@ -439,6 +439,7 @@ class EventCreationHandler:
event_dict: dict, event_dict: dict,
txn_id: Optional[str] = None, txn_id: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None, prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
require_consent: bool = True, require_consent: bool = True,
) -> Tuple[EventBase, EventContext]: ) -> Tuple[EventBase, EventContext]:
""" """
@ -458,6 +459,12 @@ class EventCreationHandler:
new event. new event.
If None, they will be requested from the database. If None, they will be requested from the database.
auth_event_ids:
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
require_consent: Whether to check if the requester has require_consent: Whether to check if the requester has
consented to the privacy policy. consented to the privacy policy.
Raises: Raises:
@ -516,7 +523,10 @@ class EventCreationHandler:
builder.internal_metadata.txn_id = txn_id builder.internal_metadata.txn_id = txn_id
event, context = await self.create_new_client_event( event, context = await self.create_new_client_event(
builder=builder, requester=requester, prev_event_ids=prev_event_ids, builder=builder,
requester=requester,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
) )
# In an ideal world we wouldn't need the second part of this condition. However, # In an ideal world we wouldn't need the second part of this condition. However,
@ -755,6 +765,7 @@ class EventCreationHandler:
builder: EventBuilder, builder: EventBuilder,
requester: Optional[Requester] = None, requester: Optional[Requester] = None,
prev_event_ids: Optional[List[str]] = None, prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
) -> Tuple[EventBase, EventContext]: ) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client """Create a new event for a local client
@ -767,6 +778,11 @@ class EventCreationHandler:
If None, they will be requested from the database. If None, they will be requested from the database.
auth_event_ids:
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
Returns: Returns:
Tuple of created event, context Tuple of created event, context
""" """
@ -788,7 +804,9 @@ class EventCreationHandler:
builder.type == EventTypes.Create or len(prev_event_ids) > 0 builder.type == EventTypes.Create or len(prev_event_ids) > 0
), "Attempting to create an event with no prev_events" ), "Attempting to create an event with no prev_events"
event = await builder.build(prev_event_ids=prev_event_ids) event = await builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
)
context = await self.state.compute_event_context(event) context = await self.state.compute_event_context(event)
if requester: if requester:
context.app_service = requester.app_service context.app_service = requester.app_service

View file

@ -17,12 +17,10 @@ import abc
import logging import logging
import random import random
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from unpaddedbase64 import encode_base64
from synapse import types from synapse import types
from synapse.api.constants import MAX_DEPTH, AccountDataTypes, EventTypes, Membership from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -31,12 +29,8 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import EventFormatVersions
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser from synapse.storage.roommember import RoomsForUser
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
@ -1132,31 +1126,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
room_id = invite_event.room_id room_id = invite_event.room_id
target_user = invite_event.state_key target_user = invite_event.state_key
room_version = await self.store.get_room_version(room_id)
content["membership"] = Membership.LEAVE content["membership"] = Membership.LEAVE
# the auth events for the new event are the same as that of the invite, plus
# the invite itself.
#
# the prev_events are just the invite.
invite_hash = invite_event.event_id # type: Union[str, Tuple]
if room_version.event_format == EventFormatVersions.V1:
alg, h = compute_event_reference_hash(invite_event)
invite_hash = (invite_event.event_id, {alg: encode_base64(h)})
auth_events = tuple(invite_event.auth_events) + (invite_hash,)
prev_events = (invite_hash,)
# we cap depth of generated events, to ensure that they are not
# rejected by other servers (and so that they can be persisted in
# the db)
depth = min(invite_event.depth + 1, MAX_DEPTH)
event_dict = { event_dict = {
"depth": depth,
"auth_events": auth_events,
"prev_events": prev_events,
"type": EventTypes.Member, "type": EventTypes.Member,
"room_id": room_id, "room_id": room_id,
"sender": target_user, "sender": target_user,
@ -1164,24 +1137,23 @@ class RoomMemberMasterHandler(RoomMemberHandler):
"state_key": target_user, "state_key": target_user,
} }
event = create_local_event_from_event_dict( # the auth events for the new event are the same as that of the invite, plus
clock=self.clock, # the invite itself.
hostname=self.hs.hostname, #
signing_key=self.hs.signing_key, # the prev_events are just the invite.
room_version=room_version, prev_event_ids = [invite_event.event_id]
event_dict=event_dict, auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
event, context = await self.event_creation_handler.create_event(
requester,
event_dict,
txn_id=txn_id,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
) )
event.internal_metadata.outlier = True event.internal_metadata.outlier = True
event.internal_metadata.out_of_band_membership = True event.internal_metadata.out_of_band_membership = True
if txn_id is not None:
event.internal_metadata.txn_id = txn_id
if requester.access_token_id is not None:
event.internal_metadata.token_id = requester.access_token_id
EventValidator().validate_new(event, self.config)
context = await self.state_handler.compute_event_context(event)
context.app_service = requester.app_service
result_event = await self.event_creation_handler.handle_new_client_event( result_event = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[UserID.from_string(target_user)], requester, event, context, extra_users=[UserID.from_string(target_user)],
) )

View file

@ -615,7 +615,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.store.get_latest_event_ids_in_room(room_id) self.store.get_latest_event_ids_in_room(room_id)
) )
event = self.get_success(builder.build(prev_event_ids)) event = self.get_success(builder.build(prev_event_ids, None))
self.get_success(self.federation_handler.on_receive_pdu(hostname, event)) self.get_success(self.federation_handler.on_receive_pdu(hostname, event))

View file

@ -226,7 +226,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
} }
builder = factory.for_room_version(room_version, event_dict) builder = factory.for_room_version(room_version, event_dict)
join_event = self.get_success(builder.build(prev_event_ids)) join_event = self.get_success(builder.build(prev_event_ids, None))
self.get_success(federation.on_send_join_request(remote_server, join_event)) self.get_success(federation.on_send_join_request(remote_server, join_event))
self.replicate() self.replicate()

View file

@ -236,9 +236,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._event_id = event_id self._event_id = event_id
@defer.inlineCallbacks @defer.inlineCallbacks
def build(self, prev_event_ids): def build(self, prev_event_ids, auth_event_ids):
built_event = yield defer.ensureDeferred( built_event = yield defer.ensureDeferred(
self._base_builder.build(prev_event_ids) self._base_builder.build(prev_event_ids, auth_event_ids)
) )
built_event._event_id = self._event_id built_event._event_id = self._event_id