Merge pull request #8537 from matrix-org/rav/simplify_locally_reject_invite

Simplify `_locally_reject_invite`
This commit is contained in:
Richard van der Hoff 2020-10-15 10:20:19 +01:00 committed by GitHub
commit 4433d01519
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 56 additions and 65 deletions

1
changelog.d/8537.misc Normal file
View file

@ -0,0 +1 @@
Factor out common code between `RoomMemberHandler._locally_reject_invite` and `EventCreationHandler.create_event`.

View file

@ -97,32 +97,37 @@ class EventBuilder:
def is_state(self): def is_state(self):
return self._state_key is not None return self._state_key is not None
async def build(self, prev_event_ids: List[str]) -> EventBase: async def build(
self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]]
) -> EventBase:
"""Transform into a fully signed and hashed event """Transform into a fully signed and hashed event
Args: Args:
prev_event_ids: The event IDs to use as the prev events prev_event_ids: The event IDs to use as the prev events
auth_event_ids: The event IDs to use as the auth events.
Should normally be set to None, which will cause them to be calculated
based on the room state at the prev_events.
Returns: Returns:
The signed and hashed event. The signed and hashed event.
""" """
if auth_event_ids is None:
state_ids = await self._state.get_current_state_ids( state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids self.room_id, prev_event_ids
) )
auth_ids = self._auth.compute_auth_events(self, state_ids) auth_event_ids = self._auth.compute_auth_events(self, state_ids)
format_version = self.room_version.event_format format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1: if format_version == EventFormatVersions.V1:
# The types of auth/prev events changes between event versions. # The types of auth/prev events changes between event versions.
auth_events = await self._store.add_event_hashes( auth_events = await self._store.add_event_hashes(
auth_ids auth_event_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
prev_events = await self._store.add_event_hashes( prev_events = await self._store.add_event_hashes(
prev_event_ids prev_event_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
else: else:
auth_events = auth_ids auth_events = auth_event_ids
prev_events = prev_event_ids prev_events = prev_event_ids
old_depth = await self._store.get_max_depth_of(prev_event_ids) old_depth = await self._store.get_max_depth_of(prev_event_ids)

View file

@ -437,9 +437,9 @@ class EventCreationHandler:
self, self,
requester: Requester, requester: Requester,
event_dict: dict, event_dict: dict,
token_id: Optional[str] = None,
txn_id: Optional[str] = None, txn_id: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None, prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
require_consent: bool = True, require_consent: bool = True,
) -> Tuple[EventBase, EventContext]: ) -> Tuple[EventBase, EventContext]:
""" """
@ -453,13 +453,18 @@ class EventCreationHandler:
Args: Args:
requester requester
event_dict: An entire event event_dict: An entire event
token_id
txn_id txn_id
prev_event_ids: prev_event_ids:
the forward extremities to use as the prev_events for the the forward extremities to use as the prev_events for the
new event. new event.
If None, they will be requested from the database. If None, they will be requested from the database.
auth_event_ids:
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
require_consent: Whether to check if the requester has require_consent: Whether to check if the requester has
consented to the privacy policy. consented to the privacy policy.
Raises: Raises:
@ -511,14 +516,17 @@ class EventCreationHandler:
if require_consent and not is_exempt: if require_consent and not is_exempt:
await self.assert_accepted_privacy_policy(requester) await self.assert_accepted_privacy_policy(requester)
if token_id is not None: if requester.access_token_id is not None:
builder.internal_metadata.token_id = token_id builder.internal_metadata.token_id = requester.access_token_id
if txn_id is not None: if txn_id is not None:
builder.internal_metadata.txn_id = txn_id builder.internal_metadata.txn_id = txn_id
event, context = await self.create_new_client_event( event, context = await self.create_new_client_event(
builder=builder, requester=requester, prev_event_ids=prev_event_ids, builder=builder,
requester=requester,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
) )
# In an ideal world we wouldn't need the second part of this condition. However, # In an ideal world we wouldn't need the second part of this condition. However,
@ -726,7 +734,7 @@ class EventCreationHandler:
return event, event.internal_metadata.stream_ordering return event, event.internal_metadata.stream_ordering
event, context = await self.create_event( event, context = await self.create_event(
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id requester, event_dict, txn_id=txn_id
) )
assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
@ -757,6 +765,7 @@ class EventCreationHandler:
builder: EventBuilder, builder: EventBuilder,
requester: Optional[Requester] = None, requester: Optional[Requester] = None,
prev_event_ids: Optional[List[str]] = None, prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
) -> Tuple[EventBase, EventContext]: ) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client """Create a new event for a local client
@ -769,6 +778,11 @@ class EventCreationHandler:
If None, they will be requested from the database. If None, they will be requested from the database.
auth_event_ids:
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
Returns: Returns:
Tuple of created event, context Tuple of created event, context
""" """
@ -790,7 +804,9 @@ class EventCreationHandler:
builder.type == EventTypes.Create or len(prev_event_ids) > 0 builder.type == EventTypes.Create or len(prev_event_ids) > 0
), "Attempting to create an event with no prev_events" ), "Attempting to create an event with no prev_events"
event = await builder.build(prev_event_ids=prev_event_ids) event = await builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
)
context = await self.state.compute_event_context(event) context = await self.state.compute_event_context(event)
if requester: if requester:
context.app_service = requester.app_service context.app_service = requester.app_service

View file

@ -214,7 +214,6 @@ class RoomCreationHandler(BaseHandler):
"replacement_room": new_room_id, "replacement_room": new_room_id,
}, },
}, },
token_id=requester.access_token_id,
) )
old_room_version = await self.store.get_room_version_id(old_room_id) old_room_version = await self.store.get_room_version_id(old_room_id)
await self.auth.check_from_context( await self.auth.check_from_context(

View file

@ -17,12 +17,10 @@ import abc
import logging import logging
import random import random
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from unpaddedbase64 import encode_base64
from synapse import types from synapse import types
from synapse.api.constants import MAX_DEPTH, AccountDataTypes, EventTypes, Membership from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -31,12 +29,8 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import EventFormatVersions
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser from synapse.storage.roommember import RoomsForUser
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
@ -193,7 +187,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# For backwards compatibility: # For backwards compatibility:
"membership": membership, "membership": membership,
}, },
token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
prev_event_ids=prev_event_ids, prev_event_ids=prev_event_ids,
require_consent=require_consent, require_consent=require_consent,
@ -1133,31 +1126,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
room_id = invite_event.room_id room_id = invite_event.room_id
target_user = invite_event.state_key target_user = invite_event.state_key
room_version = await self.store.get_room_version(room_id)
content["membership"] = Membership.LEAVE content["membership"] = Membership.LEAVE
# the auth events for the new event are the same as that of the invite, plus
# the invite itself.
#
# the prev_events are just the invite.
invite_hash = invite_event.event_id # type: Union[str, Tuple]
if room_version.event_format == EventFormatVersions.V1:
alg, h = compute_event_reference_hash(invite_event)
invite_hash = (invite_event.event_id, {alg: encode_base64(h)})
auth_events = tuple(invite_event.auth_events) + (invite_hash,)
prev_events = (invite_hash,)
# we cap depth of generated events, to ensure that they are not
# rejected by other servers (and so that they can be persisted in
# the db)
depth = min(invite_event.depth + 1, MAX_DEPTH)
event_dict = { event_dict = {
"depth": depth,
"auth_events": auth_events,
"prev_events": prev_events,
"type": EventTypes.Member, "type": EventTypes.Member,
"room_id": room_id, "room_id": room_id,
"sender": target_user, "sender": target_user,
@ -1165,24 +1137,23 @@ class RoomMemberMasterHandler(RoomMemberHandler):
"state_key": target_user, "state_key": target_user,
} }
event = create_local_event_from_event_dict( # the auth events for the new event are the same as that of the invite, plus
clock=self.clock, # the invite itself.
hostname=self.hs.hostname, #
signing_key=self.hs.signing_key, # the prev_events are just the invite.
room_version=room_version, prev_event_ids = [invite_event.event_id]
event_dict=event_dict, auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
event, context = await self.event_creation_handler.create_event(
requester,
event_dict,
txn_id=txn_id,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
) )
event.internal_metadata.outlier = True event.internal_metadata.outlier = True
event.internal_metadata.out_of_band_membership = True event.internal_metadata.out_of_band_membership = True
if txn_id is not None:
event.internal_metadata.txn_id = txn_id
if requester.access_token_id is not None:
event.internal_metadata.token_id = requester.access_token_id
EventValidator().validate_new(event, self.config)
context = await self.state_handler.compute_event_context(event)
context.app_service = requester.app_service
result_event = await self.event_creation_handler.handle_new_client_event( result_event = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[UserID.from_string(target_user)], requester, event, context, extra_users=[UserID.from_string(target_user)],
) )

View file

@ -66,7 +66,6 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
"sender": self.requester.user.to_string(), "sender": self.requester.user.to_string(),
"content": {"msgtype": "m.text", "body": random_string(5)}, "content": {"msgtype": "m.text", "body": random_string(5)},
}, },
token_id=self.token_id,
txn_id=txn_id, txn_id=txn_id,
) )
) )

View file

@ -615,7 +615,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.store.get_latest_event_ids_in_room(room_id) self.store.get_latest_event_ids_in_room(room_id)
) )
event = self.get_success(builder.build(prev_event_ids)) event = self.get_success(builder.build(prev_event_ids, None))
self.get_success(self.federation_handler.on_receive_pdu(hostname, event)) self.get_success(self.federation_handler.on_receive_pdu(hostname, event))

View file

@ -226,7 +226,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
} }
builder = factory.for_room_version(room_version, event_dict) builder = factory.for_room_version(room_version, event_dict)
join_event = self.get_success(builder.build(prev_event_ids)) join_event = self.get_success(builder.build(prev_event_ids, None))
self.get_success(federation.on_send_join_request(remote_server, join_event)) self.get_success(federation.on_send_join_request(remote_server, join_event))
self.replicate() self.replicate()

View file

@ -236,9 +236,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._event_id = event_id self._event_id = event_id
@defer.inlineCallbacks @defer.inlineCallbacks
def build(self, prev_event_ids): def build(self, prev_event_ids, auth_event_ids):
built_event = yield defer.ensureDeferred( built_event = yield defer.ensureDeferred(
self._base_builder.build(prev_event_ids) self._base_builder.build(prev_event_ids, auth_event_ids)
) )
built_event._event_id = self._event_id built_event._event_id = self._event_id