0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-14 08:38:24 +02:00

Simplify _locally_reject_invite

Update `EventCreationHandler.create_event` to accept an auth_events param, and
use it in `_locally_reject_invite` instead of reinventing the wheel.
This commit is contained in:
Richard van der Hoff 2020-10-13 23:14:35 +01:00
parent d9d86c2996
commit a34b17e492
6 changed files with 52 additions and 57 deletions

View file

@ -97,32 +97,37 @@ class EventBuilder:
def is_state(self):
return self._state_key is not None
async def build(self, prev_event_ids: List[str]) -> EventBase:
async def build(
self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]]
) -> EventBase:
"""Transform into a fully signed and hashed event
Args:
prev_event_ids: The event IDs to use as the prev events
auth_event_ids: The event IDs to use as the auth events.
Should normally be set to None, which will cause them to be calculated
based on the room state at the prev_events.
Returns:
The signed and hashed event.
"""
state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids
)
auth_ids = self._auth.compute_auth_events(self, state_ids)
if auth_event_ids is None:
state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids
)
auth_event_ids = self._auth.compute_auth_events(self, state_ids)
format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:
# The types of auth/prev events changes between event versions.
auth_events = await self._store.add_event_hashes(
auth_ids
auth_event_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
prev_events = await self._store.add_event_hashes(
prev_event_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
else:
auth_events = auth_ids
auth_events = auth_event_ids
prev_events = prev_event_ids
old_depth = await self._store.get_max_depth_of(prev_event_ids)

View file

@ -439,6 +439,7 @@ class EventCreationHandler:
event_dict: dict,
txn_id: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
require_consent: bool = True,
) -> Tuple[EventBase, EventContext]:
"""
@ -458,6 +459,12 @@ class EventCreationHandler:
new event.
If None, they will be requested from the database.
auth_event_ids:
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
require_consent: Whether to check if the requester has
consented to the privacy policy.
Raises:
@ -516,7 +523,10 @@ class EventCreationHandler:
builder.internal_metadata.txn_id = txn_id
event, context = await self.create_new_client_event(
builder=builder, requester=requester, prev_event_ids=prev_event_ids,
builder=builder,
requester=requester,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
)
# In an ideal world we wouldn't need the second part of this condition. However,
@ -755,6 +765,7 @@ class EventCreationHandler:
builder: EventBuilder,
requester: Optional[Requester] = None,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client
@ -767,6 +778,11 @@ class EventCreationHandler:
If None, they will be requested from the database.
auth_event_ids:
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
Returns:
Tuple of created event, context
"""
@ -788,7 +804,9 @@ class EventCreationHandler:
builder.type == EventTypes.Create or len(prev_event_ids) > 0
), "Attempting to create an event with no prev_events"
event = await builder.build(prev_event_ids=prev_event_ids)
event = await builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
)
context = await self.state.compute_event_context(event)
if requester:
context.app_service = requester.app_service

View file

@ -17,12 +17,10 @@ import abc
import logging
import random
from http import HTTPStatus
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
from unpaddedbase64 import encode_base64
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse import types
from synapse.api.constants import MAX_DEPTH, AccountDataTypes, EventTypes, Membership
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
@ -31,12 +29,8 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import EventFormatVersions
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase
from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer
@ -1132,31 +1126,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
room_id = invite_event.room_id
target_user = invite_event.state_key
room_version = await self.store.get_room_version(room_id)
content["membership"] = Membership.LEAVE
# the auth events for the new event are the same as that of the invite, plus
# the invite itself.
#
# the prev_events are just the invite.
invite_hash = invite_event.event_id # type: Union[str, Tuple]
if room_version.event_format == EventFormatVersions.V1:
alg, h = compute_event_reference_hash(invite_event)
invite_hash = (invite_event.event_id, {alg: encode_base64(h)})
auth_events = tuple(invite_event.auth_events) + (invite_hash,)
prev_events = (invite_hash,)
# we cap depth of generated events, to ensure that they are not
# rejected by other servers (and so that they can be persisted in
# the db)
depth = min(invite_event.depth + 1, MAX_DEPTH)
event_dict = {
"depth": depth,
"auth_events": auth_events,
"prev_events": prev_events,
"type": EventTypes.Member,
"room_id": room_id,
"sender": target_user,
@ -1164,24 +1137,23 @@ class RoomMemberMasterHandler(RoomMemberHandler):
"state_key": target_user,
}
event = create_local_event_from_event_dict(
clock=self.clock,
hostname=self.hs.hostname,
signing_key=self.hs.signing_key,
room_version=room_version,
event_dict=event_dict,
# the auth events for the new event are the same as that of the invite, plus
# the invite itself.
#
# the prev_events are just the invite.
prev_event_ids = [invite_event.event_id]
auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
event, context = await self.event_creation_handler.create_event(
requester,
event_dict,
txn_id=txn_id,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
)
event.internal_metadata.outlier = True
event.internal_metadata.out_of_band_membership = True
if txn_id is not None:
event.internal_metadata.txn_id = txn_id
if requester.access_token_id is not None:
event.internal_metadata.token_id = requester.access_token_id
EventValidator().validate_new(event, self.config)
context = await self.state_handler.compute_event_context(event)
context.app_service = requester.app_service
result_event = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[UserID.from_string(target_user)],
)

View file

@ -615,7 +615,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.store.get_latest_event_ids_in_room(room_id)
)
event = self.get_success(builder.build(prev_event_ids))
event = self.get_success(builder.build(prev_event_ids, None))
self.get_success(self.federation_handler.on_receive_pdu(hostname, event))

View file

@ -226,7 +226,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
}
builder = factory.for_room_version(room_version, event_dict)
join_event = self.get_success(builder.build(prev_event_ids))
join_event = self.get_success(builder.build(prev_event_ids, None))
self.get_success(federation.on_send_join_request(remote_server, join_event))
self.replicate()

View file

@ -236,9 +236,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._event_id = event_id
@defer.inlineCallbacks
def build(self, prev_event_ids):
def build(self, prev_event_ids, auth_event_ids):
built_event = yield defer.ensureDeferred(
self._base_builder.build(prev_event_ids)
self._base_builder.build(prev_event_ids, auth_event_ids)
)
built_event._event_id = self._event_id