mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 17:44:01 +01:00
Update the MSC3083 support to verify if joins are from an authorized server. (#10254)
This commit is contained in:
parent
4fb92d93ea
commit
228decfce1
17 changed files with 632 additions and 98 deletions
1
changelog.d/10254.feature
Normal file
1
changelog.d/10254.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Update support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083) to consider changes in the MSC around which servers can issue join events.
|
|
@ -75,6 +75,9 @@ class Codes:
|
||||||
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
|
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
|
||||||
USER_DEACTIVATED = "M_USER_DEACTIVATED"
|
USER_DEACTIVATED = "M_USER_DEACTIVATED"
|
||||||
BAD_ALIAS = "M_BAD_ALIAS"
|
BAD_ALIAS = "M_BAD_ALIAS"
|
||||||
|
# For restricted join rules.
|
||||||
|
UNABLE_AUTHORISE_JOIN = "M_UNABLE_TO_AUTHORISE_JOIN"
|
||||||
|
UNABLE_TO_GRANT_JOIN = "M_UNABLE_TO_GRANT_JOIN"
|
||||||
|
|
||||||
|
|
||||||
class CodeMessageException(RuntimeError):
|
class CodeMessageException(RuntimeError):
|
||||||
|
|
|
@ -168,7 +168,7 @@ class RoomVersions:
|
||||||
msc2403_knocking=False,
|
msc2403_knocking=False,
|
||||||
)
|
)
|
||||||
MSC3083 = RoomVersion(
|
MSC3083 = RoomVersion(
|
||||||
"org.matrix.msc3083",
|
"org.matrix.msc3083.v2",
|
||||||
RoomDisposition.UNSTABLE,
|
RoomDisposition.UNSTABLE,
|
||||||
EventFormatVersions.V3,
|
EventFormatVersions.V3,
|
||||||
StateResolutionVersions.V2,
|
StateResolutionVersions.V2,
|
||||||
|
|
|
@ -106,6 +106,18 @@ def check(
|
||||||
if not event.signatures.get(event_id_domain):
|
if not event.signatures.get(event_id_domain):
|
||||||
raise AuthError(403, "Event not signed by sending server")
|
raise AuthError(403, "Event not signed by sending server")
|
||||||
|
|
||||||
|
is_invite_via_allow_rule = (
|
||||||
|
event.type == EventTypes.Member
|
||||||
|
and event.membership == Membership.JOIN
|
||||||
|
and "join_authorised_via_users_server" in event.content
|
||||||
|
)
|
||||||
|
if is_invite_via_allow_rule:
|
||||||
|
authoriser_domain = get_domain_from_id(
|
||||||
|
event.content["join_authorised_via_users_server"]
|
||||||
|
)
|
||||||
|
if not event.signatures.get(authoriser_domain):
|
||||||
|
raise AuthError(403, "Event not signed by authorising server")
|
||||||
|
|
||||||
# Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
|
# Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
|
||||||
#
|
#
|
||||||
# 1. If type is m.room.create:
|
# 1. If type is m.room.create:
|
||||||
|
@ -177,7 +189,7 @@ def check(
|
||||||
# https://github.com/vector-im/vector-web/issues/1208 hopefully
|
# https://github.com/vector-im/vector-web/issues/1208 hopefully
|
||||||
if event.type == EventTypes.ThirdPartyInvite:
|
if event.type == EventTypes.ThirdPartyInvite:
|
||||||
user_level = get_user_power_level(event.user_id, auth_events)
|
user_level = get_user_power_level(event.user_id, auth_events)
|
||||||
invite_level = _get_named_level(auth_events, "invite", 0)
|
invite_level = get_named_level(auth_events, "invite", 0)
|
||||||
|
|
||||||
if user_level < invite_level:
|
if user_level < invite_level:
|
||||||
raise AuthError(403, "You don't have permission to invite users")
|
raise AuthError(403, "You don't have permission to invite users")
|
||||||
|
@ -285,8 +297,8 @@ def _is_membership_change_allowed(
|
||||||
user_level = get_user_power_level(event.user_id, auth_events)
|
user_level = get_user_power_level(event.user_id, auth_events)
|
||||||
target_level = get_user_power_level(target_user_id, auth_events)
|
target_level = get_user_power_level(target_user_id, auth_events)
|
||||||
|
|
||||||
# FIXME (erikj): What should we do here as the default?
|
invite_level = get_named_level(auth_events, "invite", 0)
|
||||||
ban_level = _get_named_level(auth_events, "ban", 50)
|
ban_level = get_named_level(auth_events, "ban", 50)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"_is_membership_change_allowed: %s",
|
"_is_membership_change_allowed: %s",
|
||||||
|
@ -336,8 +348,6 @@ def _is_membership_change_allowed(
|
||||||
elif target_in_room: # the target is already in the room.
|
elif target_in_room: # the target is already in the room.
|
||||||
raise AuthError(403, "%s is already in the room." % target_user_id)
|
raise AuthError(403, "%s is already in the room." % target_user_id)
|
||||||
else:
|
else:
|
||||||
invite_level = _get_named_level(auth_events, "invite", 0)
|
|
||||||
|
|
||||||
if user_level < invite_level:
|
if user_level < invite_level:
|
||||||
raise AuthError(403, "You don't have permission to invite users")
|
raise AuthError(403, "You don't have permission to invite users")
|
||||||
elif Membership.JOIN == membership:
|
elif Membership.JOIN == membership:
|
||||||
|
@ -345,16 +355,41 @@ def _is_membership_change_allowed(
|
||||||
# * They are not banned.
|
# * They are not banned.
|
||||||
# * They are accepting a previously sent invitation.
|
# * They are accepting a previously sent invitation.
|
||||||
# * They are already joined (it's a NOOP).
|
# * They are already joined (it's a NOOP).
|
||||||
# * The room is public or restricted.
|
# * The room is public.
|
||||||
|
# * The room is restricted and the user meets the allows rules.
|
||||||
if event.user_id != target_user_id:
|
if event.user_id != target_user_id:
|
||||||
raise AuthError(403, "Cannot force another user to join.")
|
raise AuthError(403, "Cannot force another user to join.")
|
||||||
elif target_banned:
|
elif target_banned:
|
||||||
raise AuthError(403, "You are banned from this room")
|
raise AuthError(403, "You are banned from this room")
|
||||||
elif join_rule == JoinRules.PUBLIC or (
|
elif join_rule == JoinRules.PUBLIC:
|
||||||
|
pass
|
||||||
|
elif (
|
||||||
room_version.msc3083_join_rules
|
room_version.msc3083_join_rules
|
||||||
and join_rule == JoinRules.MSC3083_RESTRICTED
|
and join_rule == JoinRules.MSC3083_RESTRICTED
|
||||||
):
|
):
|
||||||
pass
|
# This is the same as public, but the event must contain a reference
|
||||||
|
# to the server who authorised the join. If the event does not contain
|
||||||
|
# the proper content it is rejected.
|
||||||
|
#
|
||||||
|
# Note that if the caller is in the room or invited, then they do
|
||||||
|
# not need to meet the allow rules.
|
||||||
|
if not caller_in_room and not caller_invited:
|
||||||
|
authorising_user = event.content.get("join_authorised_via_users_server")
|
||||||
|
|
||||||
|
if authorising_user is None:
|
||||||
|
raise AuthError(403, "Join event is missing authorising user.")
|
||||||
|
|
||||||
|
# The authorising user must be in the room.
|
||||||
|
key = (EventTypes.Member, authorising_user)
|
||||||
|
member_event = auth_events.get(key)
|
||||||
|
_check_joined_room(member_event, authorising_user, event.room_id)
|
||||||
|
|
||||||
|
authorising_user_level = get_user_power_level(
|
||||||
|
authorising_user, auth_events
|
||||||
|
)
|
||||||
|
if authorising_user_level < invite_level:
|
||||||
|
raise AuthError(403, "Join event authorised by invalid server.")
|
||||||
|
|
||||||
elif join_rule == JoinRules.INVITE or (
|
elif join_rule == JoinRules.INVITE or (
|
||||||
room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
|
room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
|
||||||
):
|
):
|
||||||
|
@ -369,7 +404,7 @@ def _is_membership_change_allowed(
|
||||||
if target_banned and user_level < ban_level:
|
if target_banned and user_level < ban_level:
|
||||||
raise AuthError(403, "You cannot unban user %s." % (target_user_id,))
|
raise AuthError(403, "You cannot unban user %s." % (target_user_id,))
|
||||||
elif target_user_id != event.user_id:
|
elif target_user_id != event.user_id:
|
||||||
kick_level = _get_named_level(auth_events, "kick", 50)
|
kick_level = get_named_level(auth_events, "kick", 50)
|
||||||
|
|
||||||
if user_level < kick_level or user_level <= target_level:
|
if user_level < kick_level or user_level <= target_level:
|
||||||
raise AuthError(403, "You cannot kick user %s." % target_user_id)
|
raise AuthError(403, "You cannot kick user %s." % target_user_id)
|
||||||
|
@ -445,7 +480,7 @@ def get_send_level(
|
||||||
|
|
||||||
|
|
||||||
def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
|
def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
|
||||||
power_levels_event = _get_power_level_event(auth_events)
|
power_levels_event = get_power_level_event(auth_events)
|
||||||
|
|
||||||
send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
|
send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
|
||||||
user_level = get_user_power_level(event.user_id, auth_events)
|
user_level = get_user_power_level(event.user_id, auth_events)
|
||||||
|
@ -485,7 +520,7 @@ def check_redaction(
|
||||||
"""
|
"""
|
||||||
user_level = get_user_power_level(event.user_id, auth_events)
|
user_level = get_user_power_level(event.user_id, auth_events)
|
||||||
|
|
||||||
redact_level = _get_named_level(auth_events, "redact", 50)
|
redact_level = get_named_level(auth_events, "redact", 50)
|
||||||
|
|
||||||
if user_level >= redact_level:
|
if user_level >= redact_level:
|
||||||
return False
|
return False
|
||||||
|
@ -600,7 +635,7 @@ def _check_power_levels(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
|
def get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
|
||||||
return auth_events.get((EventTypes.PowerLevels, ""))
|
return auth_events.get((EventTypes.PowerLevels, ""))
|
||||||
|
|
||||||
|
|
||||||
|
@ -616,7 +651,7 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
|
||||||
Returns:
|
Returns:
|
||||||
the user's power level in this room.
|
the user's power level in this room.
|
||||||
"""
|
"""
|
||||||
power_level_event = _get_power_level_event(auth_events)
|
power_level_event = get_power_level_event(auth_events)
|
||||||
if power_level_event:
|
if power_level_event:
|
||||||
level = power_level_event.content.get("users", {}).get(user_id)
|
level = power_level_event.content.get("users", {}).get(user_id)
|
||||||
if not level:
|
if not level:
|
||||||
|
@ -640,8 +675,8 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def _get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
|
def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
|
||||||
power_level_event = _get_power_level_event(auth_events)
|
power_level_event = get_power_level_event(auth_events)
|
||||||
|
|
||||||
if not power_level_event:
|
if not power_level_event:
|
||||||
return default
|
return default
|
||||||
|
@ -728,7 +763,9 @@ def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
|
||||||
return public_keys
|
return public_keys
|
||||||
|
|
||||||
|
|
||||||
def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str, str]]:
|
def auth_types_for_event(
|
||||||
|
room_version: RoomVersion, event: Union[EventBase, EventBuilder]
|
||||||
|
) -> Set[Tuple[str, str]]:
|
||||||
"""Given an event, return a list of (EventType, StateKey) that may be
|
"""Given an event, return a list of (EventType, StateKey) that may be
|
||||||
needed to auth the event. The returned list may be a superset of what
|
needed to auth the event. The returned list may be a superset of what
|
||||||
would actually be required depending on the full state of the room.
|
would actually be required depending on the full state of the room.
|
||||||
|
@ -760,4 +797,12 @@ def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str
|
||||||
)
|
)
|
||||||
auth_types.add(key)
|
auth_types.add(key)
|
||||||
|
|
||||||
|
if room_version.msc3083_join_rules and membership == Membership.JOIN:
|
||||||
|
if "join_authorised_via_users_server" in event.content:
|
||||||
|
key = (
|
||||||
|
EventTypes.Member,
|
||||||
|
event.content["join_authorised_via_users_server"],
|
||||||
|
)
|
||||||
|
auth_types.add(key)
|
||||||
|
|
||||||
return auth_types
|
return auth_types
|
||||||
|
|
|
@ -178,6 +178,34 @@ async def _check_sigs_on_pdu(
|
||||||
)
|
)
|
||||||
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
|
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
# If this is a join event for a restricted room it may have been authorised
|
||||||
|
# via a different server from the sending server. Check those signatures.
|
||||||
|
if (
|
||||||
|
room_version.msc3083_join_rules
|
||||||
|
and pdu.type == EventTypes.Member
|
||||||
|
and pdu.membership == Membership.JOIN
|
||||||
|
and "join_authorised_via_users_server" in pdu.content
|
||||||
|
):
|
||||||
|
authorising_server = get_domain_from_id(
|
||||||
|
pdu.content["join_authorised_via_users_server"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await keyring.verify_event_for_server(
|
||||||
|
authorising_server,
|
||||||
|
pdu,
|
||||||
|
pdu.origin_server_ts if room_version.enforce_key_validity else 0,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
errmsg = (
|
||||||
|
"event id %s: unable to verify signature for authorising server %s: %s"
|
||||||
|
% (
|
||||||
|
pdu.event_id,
|
||||||
|
authorising_server,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
|
||||||
def _is_invite_via_3pid(event: EventBase) -> bool:
|
def _is_invite_via_3pid(event: EventBase) -> bool:
|
||||||
return (
|
return (
|
||||||
|
|
|
@ -19,7 +19,6 @@ import itertools
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
|
@ -79,7 +78,15 @@ class InvalidResponseError(RuntimeError):
|
||||||
we couldn't parse
|
we couldn't parse
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class SendJoinResult:
|
||||||
|
# The event to persist.
|
||||||
|
event: EventBase
|
||||||
|
# A string giving the server the event was sent to.
|
||||||
|
origin: str
|
||||||
|
state: List[EventBase]
|
||||||
|
auth_chain: List[EventBase]
|
||||||
|
|
||||||
|
|
||||||
class FederationClient(FederationBase):
|
class FederationClient(FederationBase):
|
||||||
|
@ -677,7 +684,7 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
async def send_join(
|
async def send_join(
|
||||||
self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion
|
self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion
|
||||||
) -> Dict[str, Any]:
|
) -> SendJoinResult:
|
||||||
"""Sends a join event to one of a list of homeservers.
|
"""Sends a join event to one of a list of homeservers.
|
||||||
|
|
||||||
Doing so will cause the remote server to add the event to the graph,
|
Doing so will cause the remote server to add the event to the graph,
|
||||||
|
@ -691,18 +698,38 @@ class FederationClient(FederationBase):
|
||||||
did the make_join)
|
did the make_join)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a dict with members ``origin`` (a string
|
The result of the send join request.
|
||||||
giving the server the event was sent to, ``state`` (?) and
|
|
||||||
``auth_chain``.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError: if the chosen remote server returns a 300/400 code, or
|
SynapseError: if the chosen remote server returns a 300/400 code, or
|
||||||
no servers successfully handle the request.
|
no servers successfully handle the request.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def send_request(destination) -> Dict[str, Any]:
|
async def send_request(destination) -> SendJoinResult:
|
||||||
response = await self._do_send_join(room_version, destination, pdu)
|
response = await self._do_send_join(room_version, destination, pdu)
|
||||||
|
|
||||||
|
# If an event was returned (and expected to be returned):
|
||||||
|
#
|
||||||
|
# * Ensure it has the same event ID (note that the event ID is a hash
|
||||||
|
# of the event fields for versions which support MSC3083).
|
||||||
|
# * Ensure the signatures are good.
|
||||||
|
#
|
||||||
|
# Otherwise, fallback to the provided event.
|
||||||
|
if room_version.msc3083_join_rules and response.event:
|
||||||
|
event = response.event
|
||||||
|
|
||||||
|
valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
|
||||||
|
pdu=event,
|
||||||
|
origin=destination,
|
||||||
|
outlier=True,
|
||||||
|
room_version=room_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
if valid_pdu is None or event.event_id != pdu.event_id:
|
||||||
|
raise InvalidResponseError("Returned an invalid join event")
|
||||||
|
else:
|
||||||
|
event = pdu
|
||||||
|
|
||||||
state = response.state
|
state = response.state
|
||||||
auth_chain = response.auth_events
|
auth_chain = response.auth_events
|
||||||
|
|
||||||
|
@ -784,11 +811,21 @@ class FederationClient(FederationBase):
|
||||||
% (auth_chain_create_events,)
|
% (auth_chain_create_events,)
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return SendJoinResult(
|
||||||
"state": signed_state,
|
event=event,
|
||||||
"auth_chain": signed_auth,
|
state=signed_state,
|
||||||
"origin": destination,
|
auth_chain=signed_auth,
|
||||||
}
|
origin=destination,
|
||||||
|
)
|
||||||
|
|
||||||
|
if room_version.msc3083_join_rules:
|
||||||
|
# If the join is being authorised via allow rules, we need to send
|
||||||
|
# the /send_join back to the same server that was originally used
|
||||||
|
# with /make_join.
|
||||||
|
if "join_authorised_via_users_server" in pdu.content:
|
||||||
|
destinations = [
|
||||||
|
get_domain_from_id(pdu.content["join_authorised_via_users_server"])
|
||||||
|
]
|
||||||
|
|
||||||
return await self._try_destination_list("send_join", destinations, send_request)
|
return await self._try_destination_list("send_join", destinations, send_request)
|
||||||
|
|
||||||
|
|
|
@ -45,6 +45,7 @@ from synapse.api.errors import (
|
||||||
UnsupportedRoomVersionError,
|
UnsupportedRoomVersionError,
|
||||||
)
|
)
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||||
|
from synapse.crypto.event_signing import compute_event_signature
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
||||||
|
@ -64,7 +65,7 @@ from synapse.replication.http.federation import (
|
||||||
ReplicationGetQueryRestServlet,
|
ReplicationGetQueryRestServlet,
|
||||||
)
|
)
|
||||||
from synapse.storage.databases.main.lock import Lock
|
from synapse.storage.databases.main.lock import Lock
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict, get_domain_from_id
|
||||||
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
|
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
|
||||||
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
|
@ -586,7 +587,7 @@ class FederationServer(FederationBase):
|
||||||
async def on_send_join_request(
|
async def on_send_join_request(
|
||||||
self, origin: str, content: JsonDict, room_id: str
|
self, origin: str, content: JsonDict, room_id: str
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
context = await self._on_send_membership_event(
|
event, context = await self._on_send_membership_event(
|
||||||
origin, content, Membership.JOIN, room_id
|
origin, content, Membership.JOIN, room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -597,6 +598,7 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
return {
|
return {
|
||||||
|
"org.matrix.msc3083.v2.event": event.get_pdu_json(),
|
||||||
"state": [p.get_pdu_json(time_now) for p in state.values()],
|
"state": [p.get_pdu_json(time_now) for p in state.values()],
|
||||||
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
|
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
|
||||||
}
|
}
|
||||||
|
@ -681,7 +683,7 @@ class FederationServer(FederationBase):
|
||||||
Returns:
|
Returns:
|
||||||
The stripped room state.
|
The stripped room state.
|
||||||
"""
|
"""
|
||||||
event_context = await self._on_send_membership_event(
|
_, context = await self._on_send_membership_event(
|
||||||
origin, content, Membership.KNOCK, room_id
|
origin, content, Membership.KNOCK, room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -690,14 +692,14 @@ class FederationServer(FederationBase):
|
||||||
# related to the room while the knock request is pending.
|
# related to the room while the knock request is pending.
|
||||||
stripped_room_state = (
|
stripped_room_state = (
|
||||||
await self.store.get_stripped_room_state_from_event_context(
|
await self.store.get_stripped_room_state_from_event_context(
|
||||||
event_context, self._room_prejoin_state_types
|
context, self._room_prejoin_state_types
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return {"knock_state_events": stripped_room_state}
|
return {"knock_state_events": stripped_room_state}
|
||||||
|
|
||||||
async def _on_send_membership_event(
|
async def _on_send_membership_event(
|
||||||
self, origin: str, content: JsonDict, membership_type: str, room_id: str
|
self, origin: str, content: JsonDict, membership_type: str, room_id: str
|
||||||
) -> EventContext:
|
) -> Tuple[EventBase, EventContext]:
|
||||||
"""Handle an on_send_{join,leave,knock} request
|
"""Handle an on_send_{join,leave,knock} request
|
||||||
|
|
||||||
Does some preliminary validation before passing the request on to the
|
Does some preliminary validation before passing the request on to the
|
||||||
|
@ -712,7 +714,7 @@ class FederationServer(FederationBase):
|
||||||
in the event
|
in the event
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The context of the event after inserting it into the room graph.
|
The event and context of the event after inserting it into the room graph.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError if there is a problem with the request, including things like
|
SynapseError if there is a problem with the request, including things like
|
||||||
|
@ -748,6 +750,33 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures)
|
logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures)
|
||||||
|
|
||||||
|
# Sign the event since we're vouching on behalf of the remote server that
|
||||||
|
# the event is valid to be sent into the room. Currently this is only done
|
||||||
|
# if the user is being joined via restricted join rules.
|
||||||
|
if (
|
||||||
|
room_version.msc3083_join_rules
|
||||||
|
and event.membership == Membership.JOIN
|
||||||
|
and "join_authorised_via_users_server" in event.content
|
||||||
|
):
|
||||||
|
# We can only authorise our own users.
|
||||||
|
authorising_server = get_domain_from_id(
|
||||||
|
event.content["join_authorised_via_users_server"]
|
||||||
|
)
|
||||||
|
if authorising_server != self.server_name:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
f"Cannot authorise request from resident server: {authorising_server}",
|
||||||
|
)
|
||||||
|
|
||||||
|
event.signatures.update(
|
||||||
|
compute_event_signature(
|
||||||
|
room_version,
|
||||||
|
event.get_pdu_json(),
|
||||||
|
self.hs.hostname,
|
||||||
|
self.hs.signing_key,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
event = await self._check_sigs_and_hash(room_version, event)
|
event = await self._check_sigs_and_hash(room_version, event)
|
||||||
|
|
||||||
return await self.handler.on_send_membership_event(origin, event)
|
return await self.handler.on_send_membership_event(origin, event)
|
||||||
|
|
|
@ -1219,8 +1219,26 @@ def _create_v2_path(path: str, *args: str) -> str:
|
||||||
class SendJoinResponse:
|
class SendJoinResponse:
|
||||||
"""The parsed response of a `/send_join` request."""
|
"""The parsed response of a `/send_join` request."""
|
||||||
|
|
||||||
|
# The list of auth events from the /send_join response.
|
||||||
auth_events: List[EventBase]
|
auth_events: List[EventBase]
|
||||||
|
# The list of state from the /send_join response.
|
||||||
state: List[EventBase]
|
state: List[EventBase]
|
||||||
|
# The raw join event from the /send_join response.
|
||||||
|
event_dict: JsonDict
|
||||||
|
# The parsed join event from the /send_join response. This will be None if
|
||||||
|
# "event" is not included in the response.
|
||||||
|
event: Optional[EventBase] = None
|
||||||
|
|
||||||
|
|
||||||
|
@ijson.coroutine
|
||||||
|
def _event_parser(event_dict: JsonDict):
|
||||||
|
"""Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
|
||||||
|
to add them to a given dictionary.
|
||||||
|
"""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
key, value = yield
|
||||||
|
event_dict[key] = value
|
||||||
|
|
||||||
|
|
||||||
@ijson.coroutine
|
@ijson.coroutine
|
||||||
|
@ -1246,7 +1264,8 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
|
||||||
CONTENT_TYPE = "application/json"
|
CONTENT_TYPE = "application/json"
|
||||||
|
|
||||||
def __init__(self, room_version: RoomVersion, v1_api: bool):
|
def __init__(self, room_version: RoomVersion, v1_api: bool):
|
||||||
self._response = SendJoinResponse([], [])
|
self._response = SendJoinResponse([], [], {})
|
||||||
|
self._room_version = room_version
|
||||||
|
|
||||||
# The V1 API has the shape of `[200, {...}]`, which we handle by
|
# The V1 API has the shape of `[200, {...}]`, which we handle by
|
||||||
# prefixing with `item.*`.
|
# prefixing with `item.*`.
|
||||||
|
@ -1260,12 +1279,21 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
|
||||||
_event_list_parser(room_version, self._response.auth_events),
|
_event_list_parser(room_version, self._response.auth_events),
|
||||||
prefix + "auth_chain.item",
|
prefix + "auth_chain.item",
|
||||||
)
|
)
|
||||||
|
self._coro_event = ijson.kvitems_coro(
|
||||||
|
_event_parser(self._response.event_dict),
|
||||||
|
prefix + "org.matrix.msc3083.v2.event",
|
||||||
|
)
|
||||||
|
|
||||||
def write(self, data: bytes) -> int:
|
def write(self, data: bytes) -> int:
|
||||||
self._coro_state.send(data)
|
self._coro_state.send(data)
|
||||||
self._coro_auth.send(data)
|
self._coro_auth.send(data)
|
||||||
|
self._coro_event.send(data)
|
||||||
|
|
||||||
return len(data)
|
return len(data)
|
||||||
|
|
||||||
def finish(self) -> SendJoinResponse:
|
def finish(self) -> SendJoinResponse:
|
||||||
|
if self._response.event_dict:
|
||||||
|
self._response.event = make_event_from_dict(
|
||||||
|
self._response.event_dict, self._room_version
|
||||||
|
)
|
||||||
return self._response
|
return self._response
|
||||||
|
|
|
@ -11,6 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
from typing import TYPE_CHECKING, Collection, List, Optional, Union
|
from typing import TYPE_CHECKING, Collection, List, Optional, Union
|
||||||
|
|
||||||
from synapse import event_auth
|
from synapse import event_auth
|
||||||
|
@ -20,16 +21,18 @@ from synapse.api.constants import (
|
||||||
Membership,
|
Membership,
|
||||||
RestrictedJoinRuleTypes,
|
RestrictedJoinRuleTypes,
|
||||||
)
|
)
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError, Codes, SynapseError
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.builder import EventBuilder
|
from synapse.events.builder import EventBuilder
|
||||||
from synapse.types import StateMap
|
from synapse.types import StateMap, get_domain_from_id
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EventAuthHandler:
|
class EventAuthHandler:
|
||||||
"""
|
"""
|
||||||
|
@ -39,6 +42,7 @@ class EventAuthHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self._store = hs.get_datastore()
|
self._store = hs.get_datastore()
|
||||||
|
self._server_name = hs.hostname
|
||||||
|
|
||||||
async def check_from_context(
|
async def check_from_context(
|
||||||
self, room_version: str, event, context, do_sig_check=True
|
self, room_version: str, event, context, do_sig_check=True
|
||||||
|
@ -81,15 +85,76 @@ class EventAuthHandler:
|
||||||
# introduce undesirable "state reset" behaviour.
|
# introduce undesirable "state reset" behaviour.
|
||||||
#
|
#
|
||||||
# All of which sounds a bit tricky so we don't bother for now.
|
# All of which sounds a bit tricky so we don't bother for now.
|
||||||
|
|
||||||
auth_ids = []
|
auth_ids = []
|
||||||
for etype, state_key in event_auth.auth_types_for_event(event):
|
for etype, state_key in event_auth.auth_types_for_event(
|
||||||
|
event.room_version, event
|
||||||
|
):
|
||||||
auth_ev_id = current_state_ids.get((etype, state_key))
|
auth_ev_id = current_state_ids.get((etype, state_key))
|
||||||
if auth_ev_id:
|
if auth_ev_id:
|
||||||
auth_ids.append(auth_ev_id)
|
auth_ids.append(auth_ev_id)
|
||||||
|
|
||||||
return auth_ids
|
return auth_ids
|
||||||
|
|
||||||
|
async def get_user_which_could_invite(
|
||||||
|
self, room_id: str, current_state_ids: StateMap[str]
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Searches the room state for a local user who has the power level necessary
|
||||||
|
to invite other users.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: The room ID under search.
|
||||||
|
current_state_ids: The current state of the room.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The MXID of the user which could issue an invite.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SynapseError if no appropriate user is found.
|
||||||
|
"""
|
||||||
|
power_level_event_id = current_state_ids.get((EventTypes.PowerLevels, ""))
|
||||||
|
invite_level = 0
|
||||||
|
users_default_level = 0
|
||||||
|
if power_level_event_id:
|
||||||
|
power_level_event = await self._store.get_event(power_level_event_id)
|
||||||
|
invite_level = power_level_event.content.get("invite", invite_level)
|
||||||
|
users_default_level = power_level_event.content.get(
|
||||||
|
"users_default", users_default_level
|
||||||
|
)
|
||||||
|
users = power_level_event.content.get("users", {})
|
||||||
|
else:
|
||||||
|
users = {}
|
||||||
|
|
||||||
|
# Find the user with the highest power level.
|
||||||
|
users_in_room = await self._store.get_users_in_room(room_id)
|
||||||
|
# Only interested in local users.
|
||||||
|
local_users_in_room = [
|
||||||
|
u for u in users_in_room if get_domain_from_id(u) == self._server_name
|
||||||
|
]
|
||||||
|
chosen_user = max(
|
||||||
|
local_users_in_room,
|
||||||
|
key=lambda user: users.get(user, users_default_level),
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return the chosen if they can issue invites.
|
||||||
|
user_power_level = users.get(chosen_user, users_default_level)
|
||||||
|
if chosen_user and user_power_level >= invite_level:
|
||||||
|
logger.debug(
|
||||||
|
"Found a user who can issue invites %s with power level %d >= invite level %d",
|
||||||
|
chosen_user,
|
||||||
|
user_power_level,
|
||||||
|
invite_level,
|
||||||
|
)
|
||||||
|
return chosen_user
|
||||||
|
|
||||||
|
# No user was found.
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"Unable to find a user which could issue an invite",
|
||||||
|
Codes.UNABLE_TO_GRANT_JOIN,
|
||||||
|
)
|
||||||
|
|
||||||
async def check_host_in_room(self, room_id: str, host: str) -> bool:
|
async def check_host_in_room(self, room_id: str, host: str) -> bool:
|
||||||
with Measure(self._clock, "check_host_in_room"):
|
with Measure(self._clock, "check_host_in_room"):
|
||||||
return await self._store.is_host_joined(room_id, host)
|
return await self._store.is_host_joined(room_id, host)
|
||||||
|
@ -134,6 +199,18 @@ class EventAuthHandler:
|
||||||
# in any of them.
|
# in any of them.
|
||||||
allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
|
allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
|
||||||
if not await self.is_user_in_rooms(allowed_rooms, user_id):
|
if not await self.is_user_in_rooms(allowed_rooms, user_id):
|
||||||
|
|
||||||
|
# If this is a remote request, the user might be in an allowed room
|
||||||
|
# that we do not know about.
|
||||||
|
if get_domain_from_id(user_id) != self._server_name:
|
||||||
|
for room_id in allowed_rooms:
|
||||||
|
if not await self._store.is_host_joined(room_id, self._server_name):
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
f"Unable to check if {user_id} is in allowed rooms.",
|
||||||
|
Codes.UNABLE_AUTHORISE_JOIN,
|
||||||
|
)
|
||||||
|
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403,
|
403,
|
||||||
"You do not belong to any of the required rooms to join this room.",
|
"You do not belong to any of the required rooms to join this room.",
|
||||||
|
|
|
@ -1494,9 +1494,10 @@ class FederationHandler(BaseHandler):
|
||||||
host_list, event, room_version_obj
|
host_list, event, room_version_obj
|
||||||
)
|
)
|
||||||
|
|
||||||
origin = ret["origin"]
|
event = ret.event
|
||||||
state = ret["state"]
|
origin = ret.origin
|
||||||
auth_chain = ret["auth_chain"]
|
state = ret.state
|
||||||
|
auth_chain = ret.auth_chain
|
||||||
auth_chain.sort(key=lambda e: e.depth)
|
auth_chain.sort(key=lambda e: e.depth)
|
||||||
|
|
||||||
logger.debug("do_invite_join auth_chain: %s", auth_chain)
|
logger.debug("do_invite_join auth_chain: %s", auth_chain)
|
||||||
|
@ -1676,7 +1677,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
# checking the room version will check that we've actually heard of the room
|
# checking the room version will check that we've actually heard of the room
|
||||||
# (and return a 404 otherwise)
|
# (and return a 404 otherwise)
|
||||||
room_version = await self.store.get_room_version_id(room_id)
|
room_version = await self.store.get_room_version(room_id)
|
||||||
|
|
||||||
# now check that we are *still* in the room
|
# now check that we are *still* in the room
|
||||||
is_in_room = await self._event_auth_handler.check_host_in_room(
|
is_in_room = await self._event_auth_handler.check_host_in_room(
|
||||||
|
@ -1691,8 +1692,38 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
event_content = {"membership": Membership.JOIN}
|
event_content = {"membership": Membership.JOIN}
|
||||||
|
|
||||||
|
# If the current room is using restricted join rules, additional information
|
||||||
|
# may need to be included in the event content in order to efficiently
|
||||||
|
# validate the event.
|
||||||
|
#
|
||||||
|
# Note that this requires the /send_join request to come back to the
|
||||||
|
# same server.
|
||||||
|
if room_version.msc3083_join_rules:
|
||||||
|
state_ids = await self.store.get_current_state_ids(room_id)
|
||||||
|
if await self._event_auth_handler.has_restricted_join_rules(
|
||||||
|
state_ids, room_version
|
||||||
|
):
|
||||||
|
prev_member_event_id = state_ids.get((EventTypes.Member, user_id), None)
|
||||||
|
# If the user is invited or joined to the room already, then
|
||||||
|
# no additional info is needed.
|
||||||
|
include_auth_user_id = True
|
||||||
|
if prev_member_event_id:
|
||||||
|
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||||
|
include_auth_user_id = prev_member_event.membership not in (
|
||||||
|
Membership.JOIN,
|
||||||
|
Membership.INVITE,
|
||||||
|
)
|
||||||
|
|
||||||
|
if include_auth_user_id:
|
||||||
|
event_content[
|
||||||
|
"join_authorised_via_users_server"
|
||||||
|
] = await self._event_auth_handler.get_user_which_could_invite(
|
||||||
|
room_id,
|
||||||
|
state_ids,
|
||||||
|
)
|
||||||
|
|
||||||
builder = self.event_builder_factory.new(
|
builder = self.event_builder_factory.new(
|
||||||
room_version,
|
room_version.identifier,
|
||||||
{
|
{
|
||||||
"type": EventTypes.Member,
|
"type": EventTypes.Member,
|
||||||
"content": event_content,
|
"content": event_content,
|
||||||
|
@ -1710,10 +1741,13 @@ class FederationHandler(BaseHandler):
|
||||||
logger.warning("Failed to create join to %s because %s", room_id, e)
|
logger.warning("Failed to create join to %s because %s", room_id, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# Ensure the user can even join the room.
|
||||||
|
await self._check_join_restrictions(context, event)
|
||||||
|
|
||||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||||
# when we get the event back in `on_send_join_request`
|
# when we get the event back in `on_send_join_request`
|
||||||
await self._event_auth_handler.check_from_context(
|
await self._event_auth_handler.check_from_context(
|
||||||
room_version, event, context, do_sig_check=False
|
room_version.identifier, event, context, do_sig_check=False
|
||||||
)
|
)
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
@ -1958,7 +1992,7 @@ class FederationHandler(BaseHandler):
|
||||||
@log_function
|
@log_function
|
||||||
async def on_send_membership_event(
|
async def on_send_membership_event(
|
||||||
self, origin: str, event: EventBase
|
self, origin: str, event: EventBase
|
||||||
) -> EventContext:
|
) -> Tuple[EventBase, EventContext]:
|
||||||
"""
|
"""
|
||||||
We have received a join/leave/knock event for a room via send_join/leave/knock.
|
We have received a join/leave/knock event for a room via send_join/leave/knock.
|
||||||
|
|
||||||
|
@ -1981,7 +2015,7 @@ class FederationHandler(BaseHandler):
|
||||||
event: The member event that has been signed by the remote homeserver.
|
event: The member event that has been signed by the remote homeserver.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The context of the event after inserting it into the room graph.
|
The event and context of the event after inserting it into the room graph.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError if the event is not accepted into the room
|
SynapseError if the event is not accepted into the room
|
||||||
|
@ -2037,7 +2071,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
# all looks good, we can persist the event.
|
# all looks good, we can persist the event.
|
||||||
await self._run_push_actions_and_persist_event(event, context)
|
await self._run_push_actions_and_persist_event(event, context)
|
||||||
return context
|
return event, context
|
||||||
|
|
||||||
async def _check_join_restrictions(
|
async def _check_join_restrictions(
|
||||||
self, context: EventContext, event: EventBase
|
self, context: EventContext, event: EventBase
|
||||||
|
@ -2473,7 +2507,7 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now check if event pass auth against said current state
|
# Now check if event pass auth against said current state
|
||||||
auth_types = auth_types_for_event(event)
|
auth_types = auth_types_for_event(room_version_obj, event)
|
||||||
current_state_ids_list = [
|
current_state_ids_list = [
|
||||||
e for k, e in current_state_ids.items() if k in auth_types
|
e for k, e in current_state_ids.items() if k in auth_types
|
||||||
]
|
]
|
||||||
|
|
|
@ -16,7 +16,7 @@ import abc
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from synapse import types
|
from synapse import types
|
||||||
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
|
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
|
||||||
|
@ -28,6 +28,7 @@ from synapse.api.errors import (
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
|
from synapse.event_auth import get_named_level, get_power_level_event
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
|
@ -340,16 +341,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
newly_joined = True
|
newly_joined = True
|
||||||
prev_member_event = None
|
|
||||||
if prev_member_event_id:
|
if prev_member_event_id:
|
||||||
prev_member_event = await self.store.get_event(prev_member_event_id)
|
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||||
|
|
||||||
# Check if the member should be allowed access via membership in a space.
|
|
||||||
await self.event_auth_handler.check_restricted_join_rules(
|
|
||||||
prev_state_ids, event.room_version, user_id, prev_member_event
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only rate-limit if the user actually joined the room, otherwise we'll end
|
# Only rate-limit if the user actually joined the room, otherwise we'll end
|
||||||
# up blocking profile updates.
|
# up blocking profile updates.
|
||||||
if newly_joined and ratelimit:
|
if newly_joined and ratelimit:
|
||||||
|
@ -701,7 +696,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
# so don't really fit into the general auth process.
|
# so don't really fit into the general auth process.
|
||||||
raise AuthError(403, "Guest access not allowed")
|
raise AuthError(403, "Guest access not allowed")
|
||||||
|
|
||||||
if not is_host_in_room:
|
# Check if a remote join should be performed.
|
||||||
|
remote_join, remote_room_hosts = await self._should_perform_remote_join(
|
||||||
|
target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
|
||||||
|
)
|
||||||
|
if remote_join:
|
||||||
if ratelimit:
|
if ratelimit:
|
||||||
time_now_s = self.clock.time()
|
time_now_s = self.clock.time()
|
||||||
(
|
(
|
||||||
|
@ -826,6 +825,106 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
outlier=outlier,
|
outlier=outlier,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _should_perform_remote_join(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
room_id: str,
|
||||||
|
remote_room_hosts: List[str],
|
||||||
|
content: JsonDict,
|
||||||
|
is_host_in_room: bool,
|
||||||
|
) -> Tuple[bool, List[str]]:
|
||||||
|
"""
|
||||||
|
Check whether the server should do a remote join (as opposed to a local
|
||||||
|
join) for a user.
|
||||||
|
|
||||||
|
Generally a remote join is used if:
|
||||||
|
|
||||||
|
* The server is not yet in the room.
|
||||||
|
* The server is in the room, the room has restricted join rules, the user
|
||||||
|
is not joined or invited to the room, and the server does not have
|
||||||
|
another user who is capable of issuing invites.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user joining the room.
|
||||||
|
room_id: The room being joined.
|
||||||
|
remote_room_hosts: A list of remote room hosts.
|
||||||
|
content: The content to use as the event body of the join. This may
|
||||||
|
be modified.
|
||||||
|
is_host_in_room: True if the host is in the room.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of:
|
||||||
|
True if a remote join should be performed. False if the join can be
|
||||||
|
done locally.
|
||||||
|
|
||||||
|
A list of remote room hosts to use. This is an empty list if a
|
||||||
|
local join is to be done.
|
||||||
|
"""
|
||||||
|
# If the host isn't in the room, pass through the prospective hosts.
|
||||||
|
if not is_host_in_room:
|
||||||
|
return True, remote_room_hosts
|
||||||
|
|
||||||
|
# If the host is in the room, but not one of the authorised hosts
|
||||||
|
# for restricted join rules, a remote join must be used.
|
||||||
|
room_version = await self.store.get_room_version(room_id)
|
||||||
|
current_state_ids = await self.store.get_current_state_ids(room_id)
|
||||||
|
|
||||||
|
# If restricted join rules are not being used, a local join can always
|
||||||
|
# be used.
|
||||||
|
if not await self.event_auth_handler.has_restricted_join_rules(
|
||||||
|
current_state_ids, room_version
|
||||||
|
):
|
||||||
|
return False, []
|
||||||
|
|
||||||
|
# If the user is invited to the room or already joined, the join
|
||||||
|
# event can always be issued locally.
|
||||||
|
prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None)
|
||||||
|
prev_member_event = None
|
||||||
|
if prev_member_event_id:
|
||||||
|
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||||
|
if prev_member_event.membership in (
|
||||||
|
Membership.JOIN,
|
||||||
|
Membership.INVITE,
|
||||||
|
):
|
||||||
|
return False, []
|
||||||
|
|
||||||
|
# If the local host has a user who can issue invites, then a local
|
||||||
|
# join can be done.
|
||||||
|
#
|
||||||
|
# If not, generate a new list of remote hosts based on which
|
||||||
|
# can issue invites.
|
||||||
|
event_map = await self.store.get_events(current_state_ids.values())
|
||||||
|
current_state = {
|
||||||
|
state_key: event_map[event_id]
|
||||||
|
for state_key, event_id in current_state_ids.items()
|
||||||
|
}
|
||||||
|
allowed_servers = get_servers_from_users(
|
||||||
|
get_users_which_can_issue_invite(current_state)
|
||||||
|
)
|
||||||
|
|
||||||
|
# If the local server is not one of allowed servers, then a remote
|
||||||
|
# join must be done. Return the list of prospective servers based on
|
||||||
|
# which can issue invites.
|
||||||
|
if self.hs.hostname not in allowed_servers:
|
||||||
|
return True, list(allowed_servers)
|
||||||
|
|
||||||
|
# Ensure the member should be allowed access via membership in a room.
|
||||||
|
await self.event_auth_handler.check_restricted_join_rules(
|
||||||
|
current_state_ids, room_version, user_id, prev_member_event
|
||||||
|
)
|
||||||
|
|
||||||
|
# If this is going to be a local join, additional information must
|
||||||
|
# be included in the event content in order to efficiently validate
|
||||||
|
# the event.
|
||||||
|
content[
|
||||||
|
"join_authorised_via_users_server"
|
||||||
|
] = await self.event_auth_handler.get_user_which_could_invite(
|
||||||
|
room_id,
|
||||||
|
current_state_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
return False, []
|
||||||
|
|
||||||
async def transfer_room_state_on_room_upgrade(
|
async def transfer_room_state_on_room_upgrade(
|
||||||
self, old_room_id: str, room_id: str
|
self, old_room_id: str, room_id: str
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -1514,3 +1613,63 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
||||||
|
|
||||||
if membership:
|
if membership:
|
||||||
await self.store.forget(user_id, room_id)
|
await self.store.forget(user_id, room_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[str]:
|
||||||
|
"""
|
||||||
|
Return the list of users which can issue invites.
|
||||||
|
|
||||||
|
This is done by exploring the joined users and comparing their power levels
|
||||||
|
to the necessyar power level to issue an invite.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_events: state in force at this point in the room
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The users which can issue invites.
|
||||||
|
"""
|
||||||
|
invite_level = get_named_level(auth_events, "invite", 0)
|
||||||
|
users_default_level = get_named_level(auth_events, "users_default", 0)
|
||||||
|
power_level_event = get_power_level_event(auth_events)
|
||||||
|
|
||||||
|
# Custom power-levels for users.
|
||||||
|
if power_level_event:
|
||||||
|
users = power_level_event.content.get("users", {})
|
||||||
|
else:
|
||||||
|
users = {}
|
||||||
|
|
||||||
|
result = []
|
||||||
|
|
||||||
|
# Check which members are able to invite by ensuring they're joined and have
|
||||||
|
# the necessary power level.
|
||||||
|
for (event_type, state_key), event in auth_events.items():
|
||||||
|
if event_type != EventTypes.Member:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event.membership != Membership.JOIN:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if the user has a custom power level.
|
||||||
|
if users.get(state_key, users_default_level) >= invite_level:
|
||||||
|
result.append(state_key)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_servers_from_users(users: List[str]) -> Set[str]:
|
||||||
|
"""
|
||||||
|
Resolve a list of users into their servers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
users: A list of users.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A set of servers.
|
||||||
|
"""
|
||||||
|
servers = set()
|
||||||
|
for user in users:
|
||||||
|
try:
|
||||||
|
servers.add(get_domain_from_id(user))
|
||||||
|
except SynapseError:
|
||||||
|
pass
|
||||||
|
return servers
|
||||||
|
|
|
@ -636,16 +636,20 @@ class StateResolutionHandler:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with Measure(self.clock, "state._resolve_events") as m:
|
with Measure(self.clock, "state._resolve_events") as m:
|
||||||
v = KNOWN_ROOM_VERSIONS[room_version]
|
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||||
if v.state_res == StateResolutionVersions.V1:
|
if room_version_obj.state_res == StateResolutionVersions.V1:
|
||||||
return await v1.resolve_events_with_store(
|
return await v1.resolve_events_with_store(
|
||||||
room_id, state_sets, event_map, state_res_store.get_events
|
room_id,
|
||||||
|
room_version_obj,
|
||||||
|
state_sets,
|
||||||
|
event_map,
|
||||||
|
state_res_store.get_events,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return await v2.resolve_events_with_store(
|
return await v2.resolve_events_with_store(
|
||||||
self.clock,
|
self.clock,
|
||||||
room_id,
|
room_id,
|
||||||
room_version,
|
room_version_obj,
|
||||||
state_sets,
|
state_sets,
|
||||||
event_map,
|
event_map,
|
||||||
state_res_store,
|
state_res_store,
|
||||||
|
|
|
@ -29,7 +29,7 @@ from typing import (
|
||||||
from synapse import event_auth
|
from synapse import event_auth
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersion, RoomVersions
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.types import MutableStateMap, StateMap
|
from synapse.types import MutableStateMap, StateMap
|
||||||
|
|
||||||
|
@ -41,6 +41,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
|
||||||
|
|
||||||
async def resolve_events_with_store(
|
async def resolve_events_with_store(
|
||||||
room_id: str,
|
room_id: str,
|
||||||
|
room_version: RoomVersion,
|
||||||
state_sets: Sequence[StateMap[str]],
|
state_sets: Sequence[StateMap[str]],
|
||||||
event_map: Optional[Dict[str, EventBase]],
|
event_map: Optional[Dict[str, EventBase]],
|
||||||
state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
|
state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
|
||||||
|
@ -104,7 +105,7 @@ async def resolve_events_with_store(
|
||||||
# get the ids of the auth events which allow us to authenticate the
|
# get the ids of the auth events which allow us to authenticate the
|
||||||
# conflicted state, picking only from the unconflicting state.
|
# conflicted state, picking only from the unconflicting state.
|
||||||
auth_events = _create_auth_events_from_maps(
|
auth_events = _create_auth_events_from_maps(
|
||||||
unconflicted_state, conflicted_state, state_map
|
room_version, unconflicted_state, conflicted_state, state_map
|
||||||
)
|
)
|
||||||
|
|
||||||
new_needed_events = set(auth_events.values())
|
new_needed_events = set(auth_events.values())
|
||||||
|
@ -132,7 +133,7 @@ async def resolve_events_with_store(
|
||||||
state_map.update(state_map_new)
|
state_map.update(state_map_new)
|
||||||
|
|
||||||
return _resolve_with_state(
|
return _resolve_with_state(
|
||||||
unconflicted_state, conflicted_state, auth_events, state_map
|
room_version, unconflicted_state, conflicted_state, auth_events, state_map
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -187,6 +188,7 @@ def _seperate(
|
||||||
|
|
||||||
|
|
||||||
def _create_auth_events_from_maps(
|
def _create_auth_events_from_maps(
|
||||||
|
room_version: RoomVersion,
|
||||||
unconflicted_state: StateMap[str],
|
unconflicted_state: StateMap[str],
|
||||||
conflicted_state: StateMap[Set[str]],
|
conflicted_state: StateMap[Set[str]],
|
||||||
state_map: Dict[str, EventBase],
|
state_map: Dict[str, EventBase],
|
||||||
|
@ -194,6 +196,7 @@ def _create_auth_events_from_maps(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
room_version: The room version.
|
||||||
unconflicted_state: The unconflicted state map.
|
unconflicted_state: The unconflicted state map.
|
||||||
conflicted_state: The conflicted state map.
|
conflicted_state: The conflicted state map.
|
||||||
state_map:
|
state_map:
|
||||||
|
@ -205,7 +208,9 @@ def _create_auth_events_from_maps(
|
||||||
for event_ids in conflicted_state.values():
|
for event_ids in conflicted_state.values():
|
||||||
for event_id in event_ids:
|
for event_id in event_ids:
|
||||||
if event_id in state_map:
|
if event_id in state_map:
|
||||||
keys = event_auth.auth_types_for_event(state_map[event_id])
|
keys = event_auth.auth_types_for_event(
|
||||||
|
room_version, state_map[event_id]
|
||||||
|
)
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key not in auth_events:
|
if key not in auth_events:
|
||||||
auth_event_id = unconflicted_state.get(key, None)
|
auth_event_id = unconflicted_state.get(key, None)
|
||||||
|
@ -215,6 +220,7 @@ def _create_auth_events_from_maps(
|
||||||
|
|
||||||
|
|
||||||
def _resolve_with_state(
|
def _resolve_with_state(
|
||||||
|
room_version: RoomVersion,
|
||||||
unconflicted_state_ids: MutableStateMap[str],
|
unconflicted_state_ids: MutableStateMap[str],
|
||||||
conflicted_state_ids: StateMap[Set[str]],
|
conflicted_state_ids: StateMap[Set[str]],
|
||||||
auth_event_ids: StateMap[str],
|
auth_event_ids: StateMap[str],
|
||||||
|
@ -235,7 +241,9 @@ def _resolve_with_state(
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resolved_state = _resolve_state_events(conflicted_state, auth_events)
|
resolved_state = _resolve_state_events(
|
||||||
|
room_version, conflicted_state, auth_events
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to resolve state")
|
logger.exception("Failed to resolve state")
|
||||||
raise
|
raise
|
||||||
|
@ -248,7 +256,9 @@ def _resolve_with_state(
|
||||||
|
|
||||||
|
|
||||||
def _resolve_state_events(
|
def _resolve_state_events(
|
||||||
conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
|
room_version: RoomVersion,
|
||||||
|
conflicted_state: StateMap[List[EventBase]],
|
||||||
|
auth_events: MutableStateMap[EventBase],
|
||||||
) -> StateMap[EventBase]:
|
) -> StateMap[EventBase]:
|
||||||
"""This is where we actually decide which of the conflicted state to
|
"""This is where we actually decide which of the conflicted state to
|
||||||
use.
|
use.
|
||||||
|
@ -263,21 +273,27 @@ def _resolve_state_events(
|
||||||
if POWER_KEY in conflicted_state:
|
if POWER_KEY in conflicted_state:
|
||||||
events = conflicted_state[POWER_KEY]
|
events = conflicted_state[POWER_KEY]
|
||||||
logger.debug("Resolving conflicted power levels %r", events)
|
logger.debug("Resolving conflicted power levels %r", events)
|
||||||
resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events)
|
resolved_state[POWER_KEY] = _resolve_auth_events(
|
||||||
|
room_version, events, auth_events
|
||||||
|
)
|
||||||
|
|
||||||
auth_events.update(resolved_state)
|
auth_events.update(resolved_state)
|
||||||
|
|
||||||
for key, events in conflicted_state.items():
|
for key, events in conflicted_state.items():
|
||||||
if key[0] == EventTypes.JoinRules:
|
if key[0] == EventTypes.JoinRules:
|
||||||
logger.debug("Resolving conflicted join rules %r", events)
|
logger.debug("Resolving conflicted join rules %r", events)
|
||||||
resolved_state[key] = _resolve_auth_events(events, auth_events)
|
resolved_state[key] = _resolve_auth_events(
|
||||||
|
room_version, events, auth_events
|
||||||
|
)
|
||||||
|
|
||||||
auth_events.update(resolved_state)
|
auth_events.update(resolved_state)
|
||||||
|
|
||||||
for key, events in conflicted_state.items():
|
for key, events in conflicted_state.items():
|
||||||
if key[0] == EventTypes.Member:
|
if key[0] == EventTypes.Member:
|
||||||
logger.debug("Resolving conflicted member lists %r", events)
|
logger.debug("Resolving conflicted member lists %r", events)
|
||||||
resolved_state[key] = _resolve_auth_events(events, auth_events)
|
resolved_state[key] = _resolve_auth_events(
|
||||||
|
room_version, events, auth_events
|
||||||
|
)
|
||||||
|
|
||||||
auth_events.update(resolved_state)
|
auth_events.update(resolved_state)
|
||||||
|
|
||||||
|
@ -290,12 +306,14 @@ def _resolve_state_events(
|
||||||
|
|
||||||
|
|
||||||
def _resolve_auth_events(
|
def _resolve_auth_events(
|
||||||
events: List[EventBase], auth_events: StateMap[EventBase]
|
room_version: RoomVersion, events: List[EventBase], auth_events: StateMap[EventBase]
|
||||||
) -> EventBase:
|
) -> EventBase:
|
||||||
reverse = list(reversed(_ordered_events(events)))
|
reverse = list(reversed(_ordered_events(events)))
|
||||||
|
|
||||||
auth_keys = {
|
auth_keys = {
|
||||||
key for event in events for key in event_auth.auth_types_for_event(event)
|
key
|
||||||
|
for event in events
|
||||||
|
for key in event_auth.auth_types_for_event(room_version, event)
|
||||||
}
|
}
|
||||||
|
|
||||||
new_auth_events = {}
|
new_auth_events = {}
|
||||||
|
|
|
@ -36,7 +36,7 @@ import synapse.state
|
||||||
from synapse import event_auth
|
from synapse import event_auth
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import RoomVersion
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.types import MutableStateMap, StateMap
|
from synapse.types import MutableStateMap, StateMap
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
@ -53,7 +53,7 @@ _AWAIT_AFTER_ITERATIONS = 100
|
||||||
async def resolve_events_with_store(
|
async def resolve_events_with_store(
|
||||||
clock: Clock,
|
clock: Clock,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
room_version: str,
|
room_version: RoomVersion,
|
||||||
state_sets: Sequence[StateMap[str]],
|
state_sets: Sequence[StateMap[str]],
|
||||||
event_map: Optional[Dict[str, EventBase]],
|
event_map: Optional[Dict[str, EventBase]],
|
||||||
state_res_store: "synapse.state.StateResolutionStore",
|
state_res_store: "synapse.state.StateResolutionStore",
|
||||||
|
@ -497,7 +497,7 @@ async def _reverse_topological_power_sort(
|
||||||
async def _iterative_auth_checks(
|
async def _iterative_auth_checks(
|
||||||
clock: Clock,
|
clock: Clock,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
room_version: str,
|
room_version: RoomVersion,
|
||||||
event_ids: List[str],
|
event_ids: List[str],
|
||||||
base_state: StateMap[str],
|
base_state: StateMap[str],
|
||||||
event_map: Dict[str, EventBase],
|
event_map: Dict[str, EventBase],
|
||||||
|
@ -519,7 +519,6 @@ async def _iterative_auth_checks(
|
||||||
Returns the final updated state
|
Returns the final updated state
|
||||||
"""
|
"""
|
||||||
resolved_state = dict(base_state)
|
resolved_state = dict(base_state)
|
||||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
|
||||||
|
|
||||||
for idx, event_id in enumerate(event_ids, start=1):
|
for idx, event_id in enumerate(event_ids, start=1):
|
||||||
event = event_map[event_id]
|
event = event_map[event_id]
|
||||||
|
@ -538,7 +537,7 @@ async def _iterative_auth_checks(
|
||||||
if ev.rejected_reason is None:
|
if ev.rejected_reason is None:
|
||||||
auth_events[(ev.type, ev.state_key)] = ev
|
auth_events[(ev.type, ev.state_key)] = ev
|
||||||
|
|
||||||
for key in event_auth.auth_types_for_event(event):
|
for key in event_auth.auth_types_for_event(room_version, event):
|
||||||
if key in resolved_state:
|
if key in resolved_state:
|
||||||
ev_id = resolved_state[key]
|
ev_id = resolved_state[key]
|
||||||
ev = await _get_event(room_id, ev_id, event_map, state_res_store)
|
ev = await _get_event(room_id, ev_id, event_map, state_res_store)
|
||||||
|
@ -548,7 +547,7 @@ async def _iterative_auth_checks(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
event_auth.check(
|
event_auth.check(
|
||||||
room_version_obj,
|
room_version,
|
||||||
event,
|
event,
|
||||||
auth_events,
|
auth_events,
|
||||||
do_sig_check=False,
|
do_sig_check=False,
|
||||||
|
|
|
@ -484,7 +484,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
state_d = resolve_events_with_store(
|
state_d = resolve_events_with_store(
|
||||||
FakeClock(),
|
FakeClock(),
|
||||||
ROOM_ID,
|
ROOM_ID,
|
||||||
RoomVersions.V2.identifier,
|
RoomVersions.V2,
|
||||||
[state_at_event[n] for n in prev_events],
|
[state_at_event[n] for n in prev_events],
|
||||||
event_map=event_map,
|
event_map=event_map,
|
||||||
state_res_store=TestStateResolutionStore(event_map),
|
state_res_store=TestStateResolutionStore(event_map),
|
||||||
|
@ -496,7 +496,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
if fake_event.state_key is not None:
|
if fake_event.state_key is not None:
|
||||||
state_after[(fake_event.type, fake_event.state_key)] = event_id
|
state_after[(fake_event.type, fake_event.state_key)] = event_id
|
||||||
|
|
||||||
auth_types = set(auth_types_for_event(fake_event))
|
auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event))
|
||||||
|
|
||||||
auth_events = []
|
auth_events = []
|
||||||
for key in auth_types:
|
for key in auth_types:
|
||||||
|
@ -633,7 +633,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
|
||||||
state_d = resolve_events_with_store(
|
state_d = resolve_events_with_store(
|
||||||
FakeClock(),
|
FakeClock(),
|
||||||
ROOM_ID,
|
ROOM_ID,
|
||||||
RoomVersions.V2.identifier,
|
RoomVersions.V2,
|
||||||
[self.state_at_bob, self.state_at_charlie],
|
[self.state_at_bob, self.state_at_charlie],
|
||||||
event_map=None,
|
event_map=None,
|
||||||
state_res_store=TestStateResolutionStore(self.event_map),
|
state_res_store=TestStateResolutionStore(self.event_map),
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
|
|
||||||
|
@ -234,8 +234,8 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
async def build(
|
async def build(
|
||||||
self,
|
self,
|
||||||
prev_event_ids,
|
prev_event_ids: List[str],
|
||||||
auth_event_ids,
|
auth_event_ids: Optional[List[str]],
|
||||||
depth: Optional[int] = None,
|
depth: Optional[int] = None,
|
||||||
):
|
):
|
||||||
built_event = await self._base_builder.build(
|
built_event = await self._base_builder.build(
|
||||||
|
|
|
@ -351,7 +351,11 @@ class EventAuthTestCase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Test joining a restricted room from MSC3083.
|
Test joining a restricted room from MSC3083.
|
||||||
|
|
||||||
This is pretty much the same test as public.
|
This is similar to the public test, but has some additional checks on
|
||||||
|
signatures.
|
||||||
|
|
||||||
|
The checks which care about signatures fake them by simply adding an
|
||||||
|
object of the proper form, not generating valid signatures.
|
||||||
"""
|
"""
|
||||||
creator = "@creator:example.com"
|
creator = "@creator:example.com"
|
||||||
pleb = "@joiner:example.com"
|
pleb = "@joiner:example.com"
|
||||||
|
@ -359,6 +363,7 @@ class EventAuthTestCase(unittest.TestCase):
|
||||||
auth_events = {
|
auth_events = {
|
||||||
("m.room.create", ""): _create_event(creator),
|
("m.room.create", ""): _create_event(creator),
|
||||||
("m.room.member", creator): _join_event(creator),
|
("m.room.member", creator): _join_event(creator),
|
||||||
|
("m.room.power_levels", ""): _power_levels_event(creator, {"invite": 0}),
|
||||||
("m.room.join_rules", ""): _join_rules_event(creator, "restricted"),
|
("m.room.join_rules", ""): _join_rules_event(creator, "restricted"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -371,19 +376,81 @@ class EventAuthTestCase(unittest.TestCase):
|
||||||
do_sig_check=False,
|
do_sig_check=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check join.
|
# A properly formatted join event should work.
|
||||||
|
authorised_join_event = _join_event(
|
||||||
|
pleb,
|
||||||
|
additional_content={
|
||||||
|
"join_authorised_via_users_server": "@creator:example.com"
|
||||||
|
},
|
||||||
|
)
|
||||||
event_auth.check(
|
event_auth.check(
|
||||||
RoomVersions.MSC3083,
|
RoomVersions.MSC3083,
|
||||||
_join_event(pleb),
|
authorised_join_event,
|
||||||
auth_events,
|
auth_events,
|
||||||
do_sig_check=False,
|
do_sig_check=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# A user cannot be force-joined to a room.
|
# A join issued by a specific user works (i.e. the power level checks
|
||||||
|
# are done properly).
|
||||||
|
pl_auth_events = auth_events.copy()
|
||||||
|
pl_auth_events[("m.room.power_levels", "")] = _power_levels_event(
|
||||||
|
creator, {"invite": 100, "users": {"@inviter:foo.test": 150}}
|
||||||
|
)
|
||||||
|
pl_auth_events[("m.room.member", "@inviter:foo.test")] = _join_event(
|
||||||
|
"@inviter:foo.test"
|
||||||
|
)
|
||||||
|
event_auth.check(
|
||||||
|
RoomVersions.MSC3083,
|
||||||
|
_join_event(
|
||||||
|
pleb,
|
||||||
|
additional_content={
|
||||||
|
"join_authorised_via_users_server": "@inviter:foo.test"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
pl_auth_events,
|
||||||
|
do_sig_check=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# A join which is missing an authorised server is rejected.
|
||||||
with self.assertRaises(AuthError):
|
with self.assertRaises(AuthError):
|
||||||
event_auth.check(
|
event_auth.check(
|
||||||
RoomVersions.MSC3083,
|
RoomVersions.MSC3083,
|
||||||
_member_event(pleb, "join", sender=creator),
|
_join_event(pleb),
|
||||||
|
auth_events,
|
||||||
|
do_sig_check=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# An join authorised by a user who is not in the room is rejected.
|
||||||
|
pl_auth_events = auth_events.copy()
|
||||||
|
pl_auth_events[("m.room.power_levels", "")] = _power_levels_event(
|
||||||
|
creator, {"invite": 100, "users": {"@other:example.com": 150}}
|
||||||
|
)
|
||||||
|
with self.assertRaises(AuthError):
|
||||||
|
event_auth.check(
|
||||||
|
RoomVersions.MSC3083,
|
||||||
|
_join_event(
|
||||||
|
pleb,
|
||||||
|
additional_content={
|
||||||
|
"join_authorised_via_users_server": "@other:example.com"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
auth_events,
|
||||||
|
do_sig_check=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# A user cannot be force-joined to a room. (This uses an event which
|
||||||
|
# *would* be valid, but is sent be a different user.)
|
||||||
|
with self.assertRaises(AuthError):
|
||||||
|
event_auth.check(
|
||||||
|
RoomVersions.MSC3083,
|
||||||
|
_member_event(
|
||||||
|
pleb,
|
||||||
|
"join",
|
||||||
|
sender=creator,
|
||||||
|
additional_content={
|
||||||
|
"join_authorised_via_users_server": "@inviter:foo.test"
|
||||||
|
},
|
||||||
|
),
|
||||||
auth_events,
|
auth_events,
|
||||||
do_sig_check=False,
|
do_sig_check=False,
|
||||||
)
|
)
|
||||||
|
@ -393,7 +460,7 @@ class EventAuthTestCase(unittest.TestCase):
|
||||||
with self.assertRaises(AuthError):
|
with self.assertRaises(AuthError):
|
||||||
event_auth.check(
|
event_auth.check(
|
||||||
RoomVersions.MSC3083,
|
RoomVersions.MSC3083,
|
||||||
_join_event(pleb),
|
authorised_join_event,
|
||||||
auth_events,
|
auth_events,
|
||||||
do_sig_check=False,
|
do_sig_check=False,
|
||||||
)
|
)
|
||||||
|
@ -402,12 +469,13 @@ class EventAuthTestCase(unittest.TestCase):
|
||||||
auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
|
auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
|
||||||
event_auth.check(
|
event_auth.check(
|
||||||
RoomVersions.MSC3083,
|
RoomVersions.MSC3083,
|
||||||
_join_event(pleb),
|
authorised_join_event,
|
||||||
auth_events,
|
auth_events,
|
||||||
do_sig_check=False,
|
do_sig_check=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# A user can send a join if they're in the room.
|
# A user can send a join if they're in the room. (This doesn't need to
|
||||||
|
# be authorised since the user is already joined.)
|
||||||
auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
|
auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
|
||||||
event_auth.check(
|
event_auth.check(
|
||||||
RoomVersions.MSC3083,
|
RoomVersions.MSC3083,
|
||||||
|
@ -416,7 +484,8 @@ class EventAuthTestCase(unittest.TestCase):
|
||||||
do_sig_check=False,
|
do_sig_check=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# A user can accept an invite.
|
# A user can accept an invite. (This doesn't need to be authorised since
|
||||||
|
# the user was invited.)
|
||||||
auth_events[("m.room.member", pleb)] = _member_event(
|
auth_events[("m.room.member", pleb)] = _member_event(
|
||||||
pleb, "invite", sender=creator
|
pleb, "invite", sender=creator
|
||||||
)
|
)
|
||||||
|
@ -446,7 +515,10 @@ def _create_event(user_id: str) -> EventBase:
|
||||||
|
|
||||||
|
|
||||||
def _member_event(
|
def _member_event(
|
||||||
user_id: str, membership: str, sender: Optional[str] = None
|
user_id: str,
|
||||||
|
membership: str,
|
||||||
|
sender: Optional[str] = None,
|
||||||
|
additional_content: Optional[dict] = None,
|
||||||
) -> EventBase:
|
) -> EventBase:
|
||||||
return make_event_from_dict(
|
return make_event_from_dict(
|
||||||
{
|
{
|
||||||
|
@ -455,14 +527,14 @@ def _member_event(
|
||||||
"type": "m.room.member",
|
"type": "m.room.member",
|
||||||
"sender": sender or user_id,
|
"sender": sender or user_id,
|
||||||
"state_key": user_id,
|
"state_key": user_id,
|
||||||
"content": {"membership": membership},
|
"content": {"membership": membership, **(additional_content or {})},
|
||||||
"prev_events": [],
|
"prev_events": [],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _join_event(user_id: str) -> EventBase:
|
def _join_event(user_id: str, additional_content: Optional[dict] = None) -> EventBase:
|
||||||
return _member_event(user_id, "join")
|
return _member_event(user_id, "join", additional_content=additional_content)
|
||||||
|
|
||||||
|
|
||||||
def _power_levels_event(sender: str, content: JsonDict) -> EventBase:
|
def _power_levels_event(sender: str, content: JsonDict) -> EventBase:
|
||||||
|
|
Loading…
Reference in a new issue