0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 17:44:01 +01:00

Update the MSC3083 support to verify if joins are from an authorized server. (#10254)

This commit is contained in:
Patrick Cloke 2021-07-26 12:17:00 -04:00 committed by GitHub
parent 4fb92d93ea
commit 228decfce1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 632 additions and 98 deletions

View file

@ -0,0 +1 @@
Update support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083) to consider changes in the MSC around which servers can issue join events.

View file

@ -75,6 +75,9 @@ class Codes:
INVALID_SIGNATURE = "M_INVALID_SIGNATURE" INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED" USER_DEACTIVATED = "M_USER_DEACTIVATED"
BAD_ALIAS = "M_BAD_ALIAS" BAD_ALIAS = "M_BAD_ALIAS"
# For restricted join rules.
UNABLE_AUTHORISE_JOIN = "M_UNABLE_TO_AUTHORISE_JOIN"
UNABLE_TO_GRANT_JOIN = "M_UNABLE_TO_GRANT_JOIN"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):

View file

@ -168,7 +168,7 @@ class RoomVersions:
msc2403_knocking=False, msc2403_knocking=False,
) )
MSC3083 = RoomVersion( MSC3083 = RoomVersion(
"org.matrix.msc3083", "org.matrix.msc3083.v2",
RoomDisposition.UNSTABLE, RoomDisposition.UNSTABLE,
EventFormatVersions.V3, EventFormatVersions.V3,
StateResolutionVersions.V2, StateResolutionVersions.V2,

View file

@ -106,6 +106,18 @@ def check(
if not event.signatures.get(event_id_domain): if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server") raise AuthError(403, "Event not signed by sending server")
is_invite_via_allow_rule = (
event.type == EventTypes.Member
and event.membership == Membership.JOIN
and "join_authorised_via_users_server" in event.content
)
if is_invite_via_allow_rule:
authoriser_domain = get_domain_from_id(
event.content["join_authorised_via_users_server"]
)
if not event.signatures.get(authoriser_domain):
raise AuthError(403, "Event not signed by authorising server")
# Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules # Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
# #
# 1. If type is m.room.create: # 1. If type is m.room.create:
@ -177,7 +189,7 @@ def check(
# https://github.com/vector-im/vector-web/issues/1208 hopefully # https://github.com/vector-im/vector-web/issues/1208 hopefully
if event.type == EventTypes.ThirdPartyInvite: if event.type == EventTypes.ThirdPartyInvite:
user_level = get_user_power_level(event.user_id, auth_events) user_level = get_user_power_level(event.user_id, auth_events)
invite_level = _get_named_level(auth_events, "invite", 0) invite_level = get_named_level(auth_events, "invite", 0)
if user_level < invite_level: if user_level < invite_level:
raise AuthError(403, "You don't have permission to invite users") raise AuthError(403, "You don't have permission to invite users")
@ -285,8 +297,8 @@ def _is_membership_change_allowed(
user_level = get_user_power_level(event.user_id, auth_events) user_level = get_user_power_level(event.user_id, auth_events)
target_level = get_user_power_level(target_user_id, auth_events) target_level = get_user_power_level(target_user_id, auth_events)
# FIXME (erikj): What should we do here as the default? invite_level = get_named_level(auth_events, "invite", 0)
ban_level = _get_named_level(auth_events, "ban", 50) ban_level = get_named_level(auth_events, "ban", 50)
logger.debug( logger.debug(
"_is_membership_change_allowed: %s", "_is_membership_change_allowed: %s",
@ -336,8 +348,6 @@ def _is_membership_change_allowed(
elif target_in_room: # the target is already in the room. elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." % target_user_id) raise AuthError(403, "%s is already in the room." % target_user_id)
else: else:
invite_level = _get_named_level(auth_events, "invite", 0)
if user_level < invite_level: if user_level < invite_level:
raise AuthError(403, "You don't have permission to invite users") raise AuthError(403, "You don't have permission to invite users")
elif Membership.JOIN == membership: elif Membership.JOIN == membership:
@ -345,16 +355,41 @@ def _is_membership_change_allowed(
# * They are not banned. # * They are not banned.
# * They are accepting a previously sent invitation. # * They are accepting a previously sent invitation.
# * They are already joined (it's a NOOP). # * They are already joined (it's a NOOP).
# * The room is public or restricted. # * The room is public.
# * The room is restricted and the user meets the allows rules.
if event.user_id != target_user_id: if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.") raise AuthError(403, "Cannot force another user to join.")
elif target_banned: elif target_banned:
raise AuthError(403, "You are banned from this room") raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC or ( elif join_rule == JoinRules.PUBLIC:
pass
elif (
room_version.msc3083_join_rules room_version.msc3083_join_rules
and join_rule == JoinRules.MSC3083_RESTRICTED and join_rule == JoinRules.MSC3083_RESTRICTED
): ):
pass # This is the same as public, but the event must contain a reference
# to the server who authorised the join. If the event does not contain
# the proper content it is rejected.
#
# Note that if the caller is in the room or invited, then they do
# not need to meet the allow rules.
if not caller_in_room and not caller_invited:
authorising_user = event.content.get("join_authorised_via_users_server")
if authorising_user is None:
raise AuthError(403, "Join event is missing authorising user.")
# The authorising user must be in the room.
key = (EventTypes.Member, authorising_user)
member_event = auth_events.get(key)
_check_joined_room(member_event, authorising_user, event.room_id)
authorising_user_level = get_user_power_level(
authorising_user, auth_events
)
if authorising_user_level < invite_level:
raise AuthError(403, "Join event authorised by invalid server.")
elif join_rule == JoinRules.INVITE or ( elif join_rule == JoinRules.INVITE or (
room_version.msc2403_knocking and join_rule == JoinRules.KNOCK room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
): ):
@ -369,7 +404,7 @@ def _is_membership_change_allowed(
if target_banned and user_level < ban_level: if target_banned and user_level < ban_level:
raise AuthError(403, "You cannot unban user %s." % (target_user_id,)) raise AuthError(403, "You cannot unban user %s." % (target_user_id,))
elif target_user_id != event.user_id: elif target_user_id != event.user_id:
kick_level = _get_named_level(auth_events, "kick", 50) kick_level = get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level: if user_level < kick_level or user_level <= target_level:
raise AuthError(403, "You cannot kick user %s." % target_user_id) raise AuthError(403, "You cannot kick user %s." % target_user_id)
@ -445,7 +480,7 @@ def get_send_level(
def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool: def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
power_levels_event = _get_power_level_event(auth_events) power_levels_event = get_power_level_event(auth_events)
send_level = get_send_level(event.type, event.get("state_key"), power_levels_event) send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
user_level = get_user_power_level(event.user_id, auth_events) user_level = get_user_power_level(event.user_id, auth_events)
@ -485,7 +520,7 @@ def check_redaction(
""" """
user_level = get_user_power_level(event.user_id, auth_events) user_level = get_user_power_level(event.user_id, auth_events)
redact_level = _get_named_level(auth_events, "redact", 50) redact_level = get_named_level(auth_events, "redact", 50)
if user_level >= redact_level: if user_level >= redact_level:
return False return False
@ -600,7 +635,7 @@ def _check_power_levels(
) )
def _get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]: def get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
return auth_events.get((EventTypes.PowerLevels, "")) return auth_events.get((EventTypes.PowerLevels, ""))
@ -616,7 +651,7 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
Returns: Returns:
the user's power level in this room. the user's power level in this room.
""" """
power_level_event = _get_power_level_event(auth_events) power_level_event = get_power_level_event(auth_events)
if power_level_event: if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id) level = power_level_event.content.get("users", {}).get(user_id)
if not level: if not level:
@ -640,8 +675,8 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
return 0 return 0
def _get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int: def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
power_level_event = _get_power_level_event(auth_events) power_level_event = get_power_level_event(auth_events)
if not power_level_event: if not power_level_event:
return default return default
@ -728,7 +763,9 @@ def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
return public_keys return public_keys
def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str, str]]: def auth_types_for_event(
room_version: RoomVersion, event: Union[EventBase, EventBuilder]
) -> Set[Tuple[str, str]]:
"""Given an event, return a list of (EventType, StateKey) that may be """Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what needed to auth the event. The returned list may be a superset of what
would actually be required depending on the full state of the room. would actually be required depending on the full state of the room.
@ -760,4 +797,12 @@ def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str
) )
auth_types.add(key) auth_types.add(key)
if room_version.msc3083_join_rules and membership == Membership.JOIN:
if "join_authorised_via_users_server" in event.content:
key = (
EventTypes.Member,
event.content["join_authorised_via_users_server"],
)
auth_types.add(key)
return auth_types return auth_types

View file

@ -178,6 +178,34 @@ async def _check_sigs_on_pdu(
) )
raise SynapseError(403, errmsg, Codes.FORBIDDEN) raise SynapseError(403, errmsg, Codes.FORBIDDEN)
# If this is a join event for a restricted room it may have been authorised
# via a different server from the sending server. Check those signatures.
if (
room_version.msc3083_join_rules
and pdu.type == EventTypes.Member
and pdu.membership == Membership.JOIN
and "join_authorised_via_users_server" in pdu.content
):
authorising_server = get_domain_from_id(
pdu.content["join_authorised_via_users_server"]
)
try:
await keyring.verify_event_for_server(
authorising_server,
pdu,
pdu.origin_server_ts if room_version.enforce_key_validity else 0,
)
except Exception as e:
errmsg = (
"event id %s: unable to verify signature for authorising server %s: %s"
% (
pdu.event_id,
authorising_server,
e,
)
)
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
def _is_invite_via_3pid(event: EventBase) -> bool: def _is_invite_via_3pid(event: EventBase) -> bool:
return ( return (

View file

@ -19,7 +19,6 @@ import itertools
import logging import logging
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any,
Awaitable, Awaitable,
Callable, Callable,
Collection, Collection,
@ -79,7 +78,15 @@ class InvalidResponseError(RuntimeError):
we couldn't parse we couldn't parse
""" """
pass
@attr.s(slots=True, frozen=True, auto_attribs=True)
class SendJoinResult:
# The event to persist.
event: EventBase
# A string giving the server the event was sent to.
origin: str
state: List[EventBase]
auth_chain: List[EventBase]
class FederationClient(FederationBase): class FederationClient(FederationBase):
@ -677,7 +684,7 @@ class FederationClient(FederationBase):
async def send_join( async def send_join(
self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion
) -> Dict[str, Any]: ) -> SendJoinResult:
"""Sends a join event to one of a list of homeservers. """Sends a join event to one of a list of homeservers.
Doing so will cause the remote server to add the event to the graph, Doing so will cause the remote server to add the event to the graph,
@ -691,18 +698,38 @@ class FederationClient(FederationBase):
did the make_join) did the make_join)
Returns: Returns:
a dict with members ``origin`` (a string The result of the send join request.
giving the server the event was sent to, ``state`` (?) and
``auth_chain``.
Raises: Raises:
SynapseError: if the chosen remote server returns a 300/400 code, or SynapseError: if the chosen remote server returns a 300/400 code, or
no servers successfully handle the request. no servers successfully handle the request.
""" """
async def send_request(destination) -> Dict[str, Any]: async def send_request(destination) -> SendJoinResult:
response = await self._do_send_join(room_version, destination, pdu) response = await self._do_send_join(room_version, destination, pdu)
# If an event was returned (and expected to be returned):
#
# * Ensure it has the same event ID (note that the event ID is a hash
# of the event fields for versions which support MSC3083).
# * Ensure the signatures are good.
#
# Otherwise, fallback to the provided event.
if room_version.msc3083_join_rules and response.event:
event = response.event
valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
pdu=event,
origin=destination,
outlier=True,
room_version=room_version,
)
if valid_pdu is None or event.event_id != pdu.event_id:
raise InvalidResponseError("Returned an invalid join event")
else:
event = pdu
state = response.state state = response.state
auth_chain = response.auth_events auth_chain = response.auth_events
@ -784,11 +811,21 @@ class FederationClient(FederationBase):
% (auth_chain_create_events,) % (auth_chain_create_events,)
) )
return { return SendJoinResult(
"state": signed_state, event=event,
"auth_chain": signed_auth, state=signed_state,
"origin": destination, auth_chain=signed_auth,
} origin=destination,
)
if room_version.msc3083_join_rules:
# If the join is being authorised via allow rules, we need to send
# the /send_join back to the same server that was originally used
# with /make_join.
if "join_authorised_via_users_server" in pdu.content:
destinations = [
get_domain_from_id(pdu.content["join_authorised_via_users_server"])
]
return await self._try_destination_list("send_join", destinations, send_request) return await self._try_destination_list("send_join", destinations, send_request)

View file

@ -45,6 +45,7 @@ from synapse.api.errors import (
UnsupportedRoomVersionError, UnsupportedRoomVersionError,
) )
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.crypto.event_signing import compute_event_signature
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.federation_base import FederationBase, event_from_pdu_json
@ -64,7 +65,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet, ReplicationGetQueryRestServlet,
) )
from synapse.storage.databases.main.lock import Lock from synapse.storage.databases.main.lock import Lock
from synapse.types import JsonDict from synapse.types import JsonDict, get_domain_from_id
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
@ -586,7 +587,7 @@ class FederationServer(FederationBase):
async def on_send_join_request( async def on_send_join_request(
self, origin: str, content: JsonDict, room_id: str self, origin: str, content: JsonDict, room_id: str
) -> Dict[str, Any]: ) -> Dict[str, Any]:
context = await self._on_send_membership_event( event, context = await self._on_send_membership_event(
origin, content, Membership.JOIN, room_id origin, content, Membership.JOIN, room_id
) )
@ -597,6 +598,7 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
return { return {
"org.matrix.msc3083.v2.event": event.get_pdu_json(),
"state": [p.get_pdu_json(time_now) for p in state.values()], "state": [p.get_pdu_json(time_now) for p in state.values()],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain], "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
} }
@ -681,7 +683,7 @@ class FederationServer(FederationBase):
Returns: Returns:
The stripped room state. The stripped room state.
""" """
event_context = await self._on_send_membership_event( _, context = await self._on_send_membership_event(
origin, content, Membership.KNOCK, room_id origin, content, Membership.KNOCK, room_id
) )
@ -690,14 +692,14 @@ class FederationServer(FederationBase):
# related to the room while the knock request is pending. # related to the room while the knock request is pending.
stripped_room_state = ( stripped_room_state = (
await self.store.get_stripped_room_state_from_event_context( await self.store.get_stripped_room_state_from_event_context(
event_context, self._room_prejoin_state_types context, self._room_prejoin_state_types
) )
) )
return {"knock_state_events": stripped_room_state} return {"knock_state_events": stripped_room_state}
async def _on_send_membership_event( async def _on_send_membership_event(
self, origin: str, content: JsonDict, membership_type: str, room_id: str self, origin: str, content: JsonDict, membership_type: str, room_id: str
) -> EventContext: ) -> Tuple[EventBase, EventContext]:
"""Handle an on_send_{join,leave,knock} request """Handle an on_send_{join,leave,knock} request
Does some preliminary validation before passing the request on to the Does some preliminary validation before passing the request on to the
@ -712,7 +714,7 @@ class FederationServer(FederationBase):
in the event in the event
Returns: Returns:
The context of the event after inserting it into the room graph. The event and context of the event after inserting it into the room graph.
Raises: Raises:
SynapseError if there is a problem with the request, including things like SynapseError if there is a problem with the request, including things like
@ -748,6 +750,33 @@ class FederationServer(FederationBase):
logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures) logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures)
# Sign the event since we're vouching on behalf of the remote server that
# the event is valid to be sent into the room. Currently this is only done
# if the user is being joined via restricted join rules.
if (
room_version.msc3083_join_rules
and event.membership == Membership.JOIN
and "join_authorised_via_users_server" in event.content
):
# We can only authorise our own users.
authorising_server = get_domain_from_id(
event.content["join_authorised_via_users_server"]
)
if authorising_server != self.server_name:
raise SynapseError(
400,
f"Cannot authorise request from resident server: {authorising_server}",
)
event.signatures.update(
compute_event_signature(
room_version,
event.get_pdu_json(),
self.hs.hostname,
self.hs.signing_key,
)
)
event = await self._check_sigs_and_hash(room_version, event) event = await self._check_sigs_and_hash(room_version, event)
return await self.handler.on_send_membership_event(origin, event) return await self.handler.on_send_membership_event(origin, event)

View file

@ -1219,8 +1219,26 @@ def _create_v2_path(path: str, *args: str) -> str:
class SendJoinResponse: class SendJoinResponse:
"""The parsed response of a `/send_join` request.""" """The parsed response of a `/send_join` request."""
# The list of auth events from the /send_join response.
auth_events: List[EventBase] auth_events: List[EventBase]
# The list of state from the /send_join response.
state: List[EventBase] state: List[EventBase]
# The raw join event from the /send_join response.
event_dict: JsonDict
# The parsed join event from the /send_join response. This will be None if
# "event" is not included in the response.
event: Optional[EventBase] = None
@ijson.coroutine
def _event_parser(event_dict: JsonDict):
"""Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
to add them to a given dictionary.
"""
while True:
key, value = yield
event_dict[key] = value
@ijson.coroutine @ijson.coroutine
@ -1246,7 +1264,8 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
CONTENT_TYPE = "application/json" CONTENT_TYPE = "application/json"
def __init__(self, room_version: RoomVersion, v1_api: bool): def __init__(self, room_version: RoomVersion, v1_api: bool):
self._response = SendJoinResponse([], []) self._response = SendJoinResponse([], [], {})
self._room_version = room_version
# The V1 API has the shape of `[200, {...}]`, which we handle by # The V1 API has the shape of `[200, {...}]`, which we handle by
# prefixing with `item.*`. # prefixing with `item.*`.
@ -1260,12 +1279,21 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
_event_list_parser(room_version, self._response.auth_events), _event_list_parser(room_version, self._response.auth_events),
prefix + "auth_chain.item", prefix + "auth_chain.item",
) )
self._coro_event = ijson.kvitems_coro(
_event_parser(self._response.event_dict),
prefix + "org.matrix.msc3083.v2.event",
)
def write(self, data: bytes) -> int: def write(self, data: bytes) -> int:
self._coro_state.send(data) self._coro_state.send(data)
self._coro_auth.send(data) self._coro_auth.send(data)
self._coro_event.send(data)
return len(data) return len(data)
def finish(self) -> SendJoinResponse: def finish(self) -> SendJoinResponse:
if self._response.event_dict:
self._response.event = make_event_from_dict(
self._response.event_dict, self._room_version
)
return self._response return self._response

View file

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, List, Optional, Union from typing import TYPE_CHECKING, Collection, List, Optional, Union
from synapse import event_auth from synapse import event_auth
@ -20,16 +21,18 @@ from synapse.api.constants import (
Membership, Membership,
RestrictedJoinRuleTypes, RestrictedJoinRuleTypes,
) )
from synapse.api.errors import AuthError from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.builder import EventBuilder from synapse.events.builder import EventBuilder
from synapse.types import StateMap from synapse.types import StateMap, get_domain_from_id
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class EventAuthHandler: class EventAuthHandler:
""" """
@ -39,6 +42,7 @@ class EventAuthHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock() self._clock = hs.get_clock()
self._store = hs.get_datastore() self._store = hs.get_datastore()
self._server_name = hs.hostname
async def check_from_context( async def check_from_context(
self, room_version: str, event, context, do_sig_check=True self, room_version: str, event, context, do_sig_check=True
@ -81,15 +85,76 @@ class EventAuthHandler:
# introduce undesirable "state reset" behaviour. # introduce undesirable "state reset" behaviour.
# #
# All of which sounds a bit tricky so we don't bother for now. # All of which sounds a bit tricky so we don't bother for now.
auth_ids = [] auth_ids = []
for etype, state_key in event_auth.auth_types_for_event(event): for etype, state_key in event_auth.auth_types_for_event(
event.room_version, event
):
auth_ev_id = current_state_ids.get((etype, state_key)) auth_ev_id = current_state_ids.get((etype, state_key))
if auth_ev_id: if auth_ev_id:
auth_ids.append(auth_ev_id) auth_ids.append(auth_ev_id)
return auth_ids return auth_ids
async def get_user_which_could_invite(
self, room_id: str, current_state_ids: StateMap[str]
) -> str:
"""
Searches the room state for a local user who has the power level necessary
to invite other users.
Args:
room_id: The room ID under search.
current_state_ids: The current state of the room.
Returns:
The MXID of the user which could issue an invite.
Raises:
SynapseError if no appropriate user is found.
"""
power_level_event_id = current_state_ids.get((EventTypes.PowerLevels, ""))
invite_level = 0
users_default_level = 0
if power_level_event_id:
power_level_event = await self._store.get_event(power_level_event_id)
invite_level = power_level_event.content.get("invite", invite_level)
users_default_level = power_level_event.content.get(
"users_default", users_default_level
)
users = power_level_event.content.get("users", {})
else:
users = {}
# Find the user with the highest power level.
users_in_room = await self._store.get_users_in_room(room_id)
# Only interested in local users.
local_users_in_room = [
u for u in users_in_room if get_domain_from_id(u) == self._server_name
]
chosen_user = max(
local_users_in_room,
key=lambda user: users.get(user, users_default_level),
default=None,
)
# Return the chosen if they can issue invites.
user_power_level = users.get(chosen_user, users_default_level)
if chosen_user and user_power_level >= invite_level:
logger.debug(
"Found a user who can issue invites %s with power level %d >= invite level %d",
chosen_user,
user_power_level,
invite_level,
)
return chosen_user
# No user was found.
raise SynapseError(
400,
"Unable to find a user which could issue an invite",
Codes.UNABLE_TO_GRANT_JOIN,
)
async def check_host_in_room(self, room_id: str, host: str) -> bool: async def check_host_in_room(self, room_id: str, host: str) -> bool:
with Measure(self._clock, "check_host_in_room"): with Measure(self._clock, "check_host_in_room"):
return await self._store.is_host_joined(room_id, host) return await self._store.is_host_joined(room_id, host)
@ -134,6 +199,18 @@ class EventAuthHandler:
# in any of them. # in any of them.
allowed_rooms = await self.get_rooms_that_allow_join(state_ids) allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
if not await self.is_user_in_rooms(allowed_rooms, user_id): if not await self.is_user_in_rooms(allowed_rooms, user_id):
# If this is a remote request, the user might be in an allowed room
# that we do not know about.
if get_domain_from_id(user_id) != self._server_name:
for room_id in allowed_rooms:
if not await self._store.is_host_joined(room_id, self._server_name):
raise SynapseError(
400,
f"Unable to check if {user_id} is in allowed rooms.",
Codes.UNABLE_AUTHORISE_JOIN,
)
raise AuthError( raise AuthError(
403, 403,
"You do not belong to any of the required rooms to join this room.", "You do not belong to any of the required rooms to join this room.",

View file

@ -1494,9 +1494,10 @@ class FederationHandler(BaseHandler):
host_list, event, room_version_obj host_list, event, room_version_obj
) )
origin = ret["origin"] event = ret.event
state = ret["state"] origin = ret.origin
auth_chain = ret["auth_chain"] state = ret.state
auth_chain = ret.auth_chain
auth_chain.sort(key=lambda e: e.depth) auth_chain.sort(key=lambda e: e.depth)
logger.debug("do_invite_join auth_chain: %s", auth_chain) logger.debug("do_invite_join auth_chain: %s", auth_chain)
@ -1676,7 +1677,7 @@ class FederationHandler(BaseHandler):
# checking the room version will check that we've actually heard of the room # checking the room version will check that we've actually heard of the room
# (and return a 404 otherwise) # (and return a 404 otherwise)
room_version = await self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version(room_id)
# now check that we are *still* in the room # now check that we are *still* in the room
is_in_room = await self._event_auth_handler.check_host_in_room( is_in_room = await self._event_auth_handler.check_host_in_room(
@ -1691,8 +1692,38 @@ class FederationHandler(BaseHandler):
event_content = {"membership": Membership.JOIN} event_content = {"membership": Membership.JOIN}
# If the current room is using restricted join rules, additional information
# may need to be included in the event content in order to efficiently
# validate the event.
#
# Note that this requires the /send_join request to come back to the
# same server.
if room_version.msc3083_join_rules:
state_ids = await self.store.get_current_state_ids(room_id)
if await self._event_auth_handler.has_restricted_join_rules(
state_ids, room_version
):
prev_member_event_id = state_ids.get((EventTypes.Member, user_id), None)
# If the user is invited or joined to the room already, then
# no additional info is needed.
include_auth_user_id = True
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
include_auth_user_id = prev_member_event.membership not in (
Membership.JOIN,
Membership.INVITE,
)
if include_auth_user_id:
event_content[
"join_authorised_via_users_server"
] = await self._event_auth_handler.get_user_which_could_invite(
room_id,
state_ids,
)
builder = self.event_builder_factory.new( builder = self.event_builder_factory.new(
room_version, room_version.identifier,
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
"content": event_content, "content": event_content,
@ -1710,10 +1741,13 @@ class FederationHandler(BaseHandler):
logger.warning("Failed to create join to %s because %s", room_id, e) logger.warning("Failed to create join to %s because %s", room_id, e)
raise raise
# Ensure the user can even join the room.
await self._check_join_restrictions(context, event)
# The remote hasn't signed it yet, obviously. We'll do the full checks # The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request` # when we get the event back in `on_send_join_request`
await self._event_auth_handler.check_from_context( await self._event_auth_handler.check_from_context(
room_version, event, context, do_sig_check=False room_version.identifier, event, context, do_sig_check=False
) )
return event return event
@ -1958,7 +1992,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
async def on_send_membership_event( async def on_send_membership_event(
self, origin: str, event: EventBase self, origin: str, event: EventBase
) -> EventContext: ) -> Tuple[EventBase, EventContext]:
""" """
We have received a join/leave/knock event for a room via send_join/leave/knock. We have received a join/leave/knock event for a room via send_join/leave/knock.
@ -1981,7 +2015,7 @@ class FederationHandler(BaseHandler):
event: The member event that has been signed by the remote homeserver. event: The member event that has been signed by the remote homeserver.
Returns: Returns:
The context of the event after inserting it into the room graph. The event and context of the event after inserting it into the room graph.
Raises: Raises:
SynapseError if the event is not accepted into the room SynapseError if the event is not accepted into the room
@ -2037,7 +2071,7 @@ class FederationHandler(BaseHandler):
# all looks good, we can persist the event. # all looks good, we can persist the event.
await self._run_push_actions_and_persist_event(event, context) await self._run_push_actions_and_persist_event(event, context)
return context return event, context
async def _check_join_restrictions( async def _check_join_restrictions(
self, context: EventContext, event: EventBase self, context: EventContext, event: EventBase
@ -2473,7 +2507,7 @@ class FederationHandler(BaseHandler):
) )
# Now check if event pass auth against said current state # Now check if event pass auth against said current state
auth_types = auth_types_for_event(event) auth_types = auth_types_for_event(room_version_obj, event)
current_state_ids_list = [ current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types e for k, e in current_state_ids.items() if k in auth_types
] ]

View file

@ -16,7 +16,7 @@ import abc
import logging import logging
import random import random
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple
from synapse import types from synapse import types
from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.constants import AccountDataTypes, EventTypes, Membership
@ -28,6 +28,7 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.types import ( from synapse.types import (
@ -340,16 +341,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
newly_joined = True newly_joined = True
prev_member_event = None
if prev_member_event_id: if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id) prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN newly_joined = prev_member_event.membership != Membership.JOIN
# Check if the member should be allowed access via membership in a space.
await self.event_auth_handler.check_restricted_join_rules(
prev_state_ids, event.room_version, user_id, prev_member_event
)
# Only rate-limit if the user actually joined the room, otherwise we'll end # Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates. # up blocking profile updates.
if newly_joined and ratelimit: if newly_joined and ratelimit:
@ -701,7 +696,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# so don't really fit into the general auth process. # so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
if not is_host_in_room: # Check if a remote join should be performed.
remote_join, remote_room_hosts = await self._should_perform_remote_join(
target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
)
if remote_join:
if ratelimit: if ratelimit:
time_now_s = self.clock.time() time_now_s = self.clock.time()
( (
@ -826,6 +825,106 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
outlier=outlier, outlier=outlier,
) )
async def _should_perform_remote_join(
self,
user_id: str,
room_id: str,
remote_room_hosts: List[str],
content: JsonDict,
is_host_in_room: bool,
) -> Tuple[bool, List[str]]:
"""
Check whether the server should do a remote join (as opposed to a local
join) for a user.
Generally a remote join is used if:
* The server is not yet in the room.
* The server is in the room, the room has restricted join rules, the user
is not joined or invited to the room, and the server does not have
another user who is capable of issuing invites.
Args:
user_id: The user joining the room.
room_id: The room being joined.
remote_room_hosts: A list of remote room hosts.
content: The content to use as the event body of the join. This may
be modified.
is_host_in_room: True if the host is in the room.
Returns:
A tuple of:
True if a remote join should be performed. False if the join can be
done locally.
A list of remote room hosts to use. This is an empty list if a
local join is to be done.
"""
# If the host isn't in the room, pass through the prospective hosts.
if not is_host_in_room:
return True, remote_room_hosts
# If the host is in the room, but not one of the authorised hosts
# for restricted join rules, a remote join must be used.
room_version = await self.store.get_room_version(room_id)
current_state_ids = await self.store.get_current_state_ids(room_id)
# If restricted join rules are not being used, a local join can always
# be used.
if not await self.event_auth_handler.has_restricted_join_rules(
current_state_ids, room_version
):
return False, []
# If the user is invited to the room or already joined, the join
# event can always be issued locally.
prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None)
prev_member_event = None
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership in (
Membership.JOIN,
Membership.INVITE,
):
return False, []
# If the local host has a user who can issue invites, then a local
# join can be done.
#
# If not, generate a new list of remote hosts based on which
# can issue invites.
event_map = await self.store.get_events(current_state_ids.values())
current_state = {
state_key: event_map[event_id]
for state_key, event_id in current_state_ids.items()
}
allowed_servers = get_servers_from_users(
get_users_which_can_issue_invite(current_state)
)
# If the local server is not one of allowed servers, then a remote
# join must be done. Return the list of prospective servers based on
# which can issue invites.
if self.hs.hostname not in allowed_servers:
return True, list(allowed_servers)
# Ensure the member should be allowed access via membership in a room.
await self.event_auth_handler.check_restricted_join_rules(
current_state_ids, room_version, user_id, prev_member_event
)
# If this is going to be a local join, additional information must
# be included in the event content in order to efficiently validate
# the event.
content[
"join_authorised_via_users_server"
] = await self.event_auth_handler.get_user_which_could_invite(
room_id,
current_state_ids,
)
return False, []
async def transfer_room_state_on_room_upgrade( async def transfer_room_state_on_room_upgrade(
self, old_room_id: str, room_id: str self, old_room_id: str, room_id: str
) -> None: ) -> None:
@ -1514,3 +1613,63 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if membership: if membership:
await self.store.forget(user_id, room_id) await self.store.forget(user_id, room_id)
def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[str]:
"""
Return the list of users which can issue invites.
This is done by exploring the joined users and comparing their power levels
to the necessyar power level to issue an invite.
Args:
auth_events: state in force at this point in the room
Returns:
The users which can issue invites.
"""
invite_level = get_named_level(auth_events, "invite", 0)
users_default_level = get_named_level(auth_events, "users_default", 0)
power_level_event = get_power_level_event(auth_events)
# Custom power-levels for users.
if power_level_event:
users = power_level_event.content.get("users", {})
else:
users = {}
result = []
# Check which members are able to invite by ensuring they're joined and have
# the necessary power level.
for (event_type, state_key), event in auth_events.items():
if event_type != EventTypes.Member:
continue
if event.membership != Membership.JOIN:
continue
# Check if the user has a custom power level.
if users.get(state_key, users_default_level) >= invite_level:
result.append(state_key)
return result
def get_servers_from_users(users: List[str]) -> Set[str]:
"""
Resolve a list of users into their servers.
Args:
users: A list of users.
Returns:
A set of servers.
"""
servers = set()
for user in users:
try:
servers.add(get_domain_from_id(user))
except SynapseError:
pass
return servers

View file

@ -636,16 +636,20 @@ class StateResolutionHandler:
""" """
try: try:
with Measure(self.clock, "state._resolve_events") as m: with Measure(self.clock, "state._resolve_events") as m:
v = KNOWN_ROOM_VERSIONS[room_version] room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
if v.state_res == StateResolutionVersions.V1: if room_version_obj.state_res == StateResolutionVersions.V1:
return await v1.resolve_events_with_store( return await v1.resolve_events_with_store(
room_id, state_sets, event_map, state_res_store.get_events room_id,
room_version_obj,
state_sets,
event_map,
state_res_store.get_events,
) )
else: else:
return await v2.resolve_events_with_store( return await v2.resolve_events_with_store(
self.clock, self.clock,
room_id, room_id,
room_version, room_version_obj,
state_sets, state_sets,
event_map, event_map,
state_res_store, state_res_store,

View file

@ -29,7 +29,7 @@ from typing import (
from synapse import event_auth from synapse import event_auth
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap from synapse.types import MutableStateMap, StateMap
@ -41,6 +41,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
async def resolve_events_with_store( async def resolve_events_with_store(
room_id: str, room_id: str,
room_version: RoomVersion,
state_sets: Sequence[StateMap[str]], state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]], event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]], state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
@ -104,7 +105,7 @@ async def resolve_events_with_store(
# get the ids of the auth events which allow us to authenticate the # get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state. # conflicted state, picking only from the unconflicting state.
auth_events = _create_auth_events_from_maps( auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map room_version, unconflicted_state, conflicted_state, state_map
) )
new_needed_events = set(auth_events.values()) new_needed_events = set(auth_events.values())
@ -132,7 +133,7 @@ async def resolve_events_with_store(
state_map.update(state_map_new) state_map.update(state_map_new)
return _resolve_with_state( return _resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map room_version, unconflicted_state, conflicted_state, auth_events, state_map
) )
@ -187,6 +188,7 @@ def _seperate(
def _create_auth_events_from_maps( def _create_auth_events_from_maps(
room_version: RoomVersion,
unconflicted_state: StateMap[str], unconflicted_state: StateMap[str],
conflicted_state: StateMap[Set[str]], conflicted_state: StateMap[Set[str]],
state_map: Dict[str, EventBase], state_map: Dict[str, EventBase],
@ -194,6 +196,7 @@ def _create_auth_events_from_maps(
""" """
Args: Args:
room_version: The room version.
unconflicted_state: The unconflicted state map. unconflicted_state: The unconflicted state map.
conflicted_state: The conflicted state map. conflicted_state: The conflicted state map.
state_map: state_map:
@ -205,7 +208,9 @@ def _create_auth_events_from_maps(
for event_ids in conflicted_state.values(): for event_ids in conflicted_state.values():
for event_id in event_ids: for event_id in event_ids:
if event_id in state_map: if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id]) keys = event_auth.auth_types_for_event(
room_version, state_map[event_id]
)
for key in keys: for key in keys:
if key not in auth_events: if key not in auth_events:
auth_event_id = unconflicted_state.get(key, None) auth_event_id = unconflicted_state.get(key, None)
@ -215,6 +220,7 @@ def _create_auth_events_from_maps(
def _resolve_with_state( def _resolve_with_state(
room_version: RoomVersion,
unconflicted_state_ids: MutableStateMap[str], unconflicted_state_ids: MutableStateMap[str],
conflicted_state_ids: StateMap[Set[str]], conflicted_state_ids: StateMap[Set[str]],
auth_event_ids: StateMap[str], auth_event_ids: StateMap[str],
@ -235,7 +241,9 @@ def _resolve_with_state(
} }
try: try:
resolved_state = _resolve_state_events(conflicted_state, auth_events) resolved_state = _resolve_state_events(
room_version, conflicted_state, auth_events
)
except Exception: except Exception:
logger.exception("Failed to resolve state") logger.exception("Failed to resolve state")
raise raise
@ -248,7 +256,9 @@ def _resolve_with_state(
def _resolve_state_events( def _resolve_state_events(
conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase] room_version: RoomVersion,
conflicted_state: StateMap[List[EventBase]],
auth_events: MutableStateMap[EventBase],
) -> StateMap[EventBase]: ) -> StateMap[EventBase]:
"""This is where we actually decide which of the conflicted state to """This is where we actually decide which of the conflicted state to
use. use.
@ -263,21 +273,27 @@ def _resolve_state_events(
if POWER_KEY in conflicted_state: if POWER_KEY in conflicted_state:
events = conflicted_state[POWER_KEY] events = conflicted_state[POWER_KEY]
logger.debug("Resolving conflicted power levels %r", events) logger.debug("Resolving conflicted power levels %r", events)
resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events) resolved_state[POWER_KEY] = _resolve_auth_events(
room_version, events, auth_events
)
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in conflicted_state.items(): for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules: if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events) logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(events, auth_events) resolved_state[key] = _resolve_auth_events(
room_version, events, auth_events
)
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in conflicted_state.items(): for key, events in conflicted_state.items():
if key[0] == EventTypes.Member: if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events) logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(events, auth_events) resolved_state[key] = _resolve_auth_events(
room_version, events, auth_events
)
auth_events.update(resolved_state) auth_events.update(resolved_state)
@ -290,12 +306,14 @@ def _resolve_state_events(
def _resolve_auth_events( def _resolve_auth_events(
events: List[EventBase], auth_events: StateMap[EventBase] room_version: RoomVersion, events: List[EventBase], auth_events: StateMap[EventBase]
) -> EventBase: ) -> EventBase:
reverse = list(reversed(_ordered_events(events))) reverse = list(reversed(_ordered_events(events)))
auth_keys = { auth_keys = {
key for event in events for key in event_auth.auth_types_for_event(event) key
for event in events
for key in event_auth.auth_types_for_event(room_version, event)
} }
new_auth_events = {} new_auth_events = {}

View file

@ -36,7 +36,7 @@ import synapse.state
from synapse import event_auth from synapse import event_auth
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap from synapse.types import MutableStateMap, StateMap
from synapse.util import Clock from synapse.util import Clock
@ -53,7 +53,7 @@ _AWAIT_AFTER_ITERATIONS = 100
async def resolve_events_with_store( async def resolve_events_with_store(
clock: Clock, clock: Clock,
room_id: str, room_id: str,
room_version: str, room_version: RoomVersion,
state_sets: Sequence[StateMap[str]], state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]], event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore", state_res_store: "synapse.state.StateResolutionStore",
@ -497,7 +497,7 @@ async def _reverse_topological_power_sort(
async def _iterative_auth_checks( async def _iterative_auth_checks(
clock: Clock, clock: Clock,
room_id: str, room_id: str,
room_version: str, room_version: RoomVersion,
event_ids: List[str], event_ids: List[str],
base_state: StateMap[str], base_state: StateMap[str],
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
@ -519,7 +519,6 @@ async def _iterative_auth_checks(
Returns the final updated state Returns the final updated state
""" """
resolved_state = dict(base_state) resolved_state = dict(base_state)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
for idx, event_id in enumerate(event_ids, start=1): for idx, event_id in enumerate(event_ids, start=1):
event = event_map[event_id] event = event_map[event_id]
@ -538,7 +537,7 @@ async def _iterative_auth_checks(
if ev.rejected_reason is None: if ev.rejected_reason is None:
auth_events[(ev.type, ev.state_key)] = ev auth_events[(ev.type, ev.state_key)] = ev
for key in event_auth.auth_types_for_event(event): for key in event_auth.auth_types_for_event(room_version, event):
if key in resolved_state: if key in resolved_state:
ev_id = resolved_state[key] ev_id = resolved_state[key]
ev = await _get_event(room_id, ev_id, event_map, state_res_store) ev = await _get_event(room_id, ev_id, event_map, state_res_store)
@ -548,7 +547,7 @@ async def _iterative_auth_checks(
try: try:
event_auth.check( event_auth.check(
room_version_obj, room_version,
event, event,
auth_events, auth_events,
do_sig_check=False, do_sig_check=False,

View file

@ -484,7 +484,7 @@ class StateTestCase(unittest.TestCase):
state_d = resolve_events_with_store( state_d = resolve_events_with_store(
FakeClock(), FakeClock(),
ROOM_ID, ROOM_ID,
RoomVersions.V2.identifier, RoomVersions.V2,
[state_at_event[n] for n in prev_events], [state_at_event[n] for n in prev_events],
event_map=event_map, event_map=event_map,
state_res_store=TestStateResolutionStore(event_map), state_res_store=TestStateResolutionStore(event_map),
@ -496,7 +496,7 @@ class StateTestCase(unittest.TestCase):
if fake_event.state_key is not None: if fake_event.state_key is not None:
state_after[(fake_event.type, fake_event.state_key)] = event_id state_after[(fake_event.type, fake_event.state_key)] = event_id
auth_types = set(auth_types_for_event(fake_event)) auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event))
auth_events = [] auth_events = []
for key in auth_types: for key in auth_types:
@ -633,7 +633,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
state_d = resolve_events_with_store( state_d = resolve_events_with_store(
FakeClock(), FakeClock(),
ROOM_ID, ROOM_ID,
RoomVersions.V2.identifier, RoomVersions.V2,
[self.state_at_bob, self.state_at_charlie], [self.state_at_bob, self.state_at_charlie],
event_map=None, event_map=None,
state_res_store=TestStateResolutionStore(self.event_map), state_res_store=TestStateResolutionStore(self.event_map),

View file

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional from typing import List, Optional
from canonicaljson import json from canonicaljson import json
@ -234,8 +234,8 @@ class RedactionTestCase(unittest.HomeserverTestCase):
async def build( async def build(
self, self,
prev_event_ids, prev_event_ids: List[str],
auth_event_ids, auth_event_ids: Optional[List[str]],
depth: Optional[int] = None, depth: Optional[int] = None,
): ):
built_event = await self._base_builder.build( built_event = await self._base_builder.build(

View file

@ -351,7 +351,11 @@ class EventAuthTestCase(unittest.TestCase):
""" """
Test joining a restricted room from MSC3083. Test joining a restricted room from MSC3083.
This is pretty much the same test as public. This is similar to the public test, but has some additional checks on
signatures.
The checks which care about signatures fake them by simply adding an
object of the proper form, not generating valid signatures.
""" """
creator = "@creator:example.com" creator = "@creator:example.com"
pleb = "@joiner:example.com" pleb = "@joiner:example.com"
@ -359,6 +363,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events = { auth_events = {
("m.room.create", ""): _create_event(creator), ("m.room.create", ""): _create_event(creator),
("m.room.member", creator): _join_event(creator), ("m.room.member", creator): _join_event(creator),
("m.room.power_levels", ""): _power_levels_event(creator, {"invite": 0}),
("m.room.join_rules", ""): _join_rules_event(creator, "restricted"), ("m.room.join_rules", ""): _join_rules_event(creator, "restricted"),
} }
@ -371,19 +376,81 @@ class EventAuthTestCase(unittest.TestCase):
do_sig_check=False, do_sig_check=False,
) )
# Check join. # A properly formatted join event should work.
authorised_join_event = _join_event(
pleb,
additional_content={
"join_authorised_via_users_server": "@creator:example.com"
},
)
event_auth.check( event_auth.check(
RoomVersions.MSC3083, RoomVersions.MSC3083,
_join_event(pleb), authorised_join_event,
auth_events, auth_events,
do_sig_check=False, do_sig_check=False,
) )
# A user cannot be force-joined to a room. # A join issued by a specific user works (i.e. the power level checks
# are done properly).
pl_auth_events = auth_events.copy()
pl_auth_events[("m.room.power_levels", "")] = _power_levels_event(
creator, {"invite": 100, "users": {"@inviter:foo.test": 150}}
)
pl_auth_events[("m.room.member", "@inviter:foo.test")] = _join_event(
"@inviter:foo.test"
)
event_auth.check(
RoomVersions.MSC3083,
_join_event(
pleb,
additional_content={
"join_authorised_via_users_server": "@inviter:foo.test"
},
),
pl_auth_events,
do_sig_check=False,
)
# A join which is missing an authorised server is rejected.
with self.assertRaises(AuthError): with self.assertRaises(AuthError):
event_auth.check( event_auth.check(
RoomVersions.MSC3083, RoomVersions.MSC3083,
_member_event(pleb, "join", sender=creator), _join_event(pleb),
auth_events,
do_sig_check=False,
)
# An join authorised by a user who is not in the room is rejected.
pl_auth_events = auth_events.copy()
pl_auth_events[("m.room.power_levels", "")] = _power_levels_event(
creator, {"invite": 100, "users": {"@other:example.com": 150}}
)
with self.assertRaises(AuthError):
event_auth.check(
RoomVersions.MSC3083,
_join_event(
pleb,
additional_content={
"join_authorised_via_users_server": "@other:example.com"
},
),
auth_events,
do_sig_check=False,
)
# A user cannot be force-joined to a room. (This uses an event which
# *would* be valid, but is sent be a different user.)
with self.assertRaises(AuthError):
event_auth.check(
RoomVersions.MSC3083,
_member_event(
pleb,
"join",
sender=creator,
additional_content={
"join_authorised_via_users_server": "@inviter:foo.test"
},
),
auth_events, auth_events,
do_sig_check=False, do_sig_check=False,
) )
@ -393,7 +460,7 @@ class EventAuthTestCase(unittest.TestCase):
with self.assertRaises(AuthError): with self.assertRaises(AuthError):
event_auth.check( event_auth.check(
RoomVersions.MSC3083, RoomVersions.MSC3083,
_join_event(pleb), authorised_join_event,
auth_events, auth_events,
do_sig_check=False, do_sig_check=False,
) )
@ -402,12 +469,13 @@ class EventAuthTestCase(unittest.TestCase):
auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
event_auth.check( event_auth.check(
RoomVersions.MSC3083, RoomVersions.MSC3083,
_join_event(pleb), authorised_join_event,
auth_events, auth_events,
do_sig_check=False, do_sig_check=False,
) )
# A user can send a join if they're in the room. # A user can send a join if they're in the room. (This doesn't need to
# be authorised since the user is already joined.)
auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
event_auth.check( event_auth.check(
RoomVersions.MSC3083, RoomVersions.MSC3083,
@ -416,7 +484,8 @@ class EventAuthTestCase(unittest.TestCase):
do_sig_check=False, do_sig_check=False,
) )
# A user can accept an invite. # A user can accept an invite. (This doesn't need to be authorised since
# the user was invited.)
auth_events[("m.room.member", pleb)] = _member_event( auth_events[("m.room.member", pleb)] = _member_event(
pleb, "invite", sender=creator pleb, "invite", sender=creator
) )
@ -446,7 +515,10 @@ def _create_event(user_id: str) -> EventBase:
def _member_event( def _member_event(
user_id: str, membership: str, sender: Optional[str] = None user_id: str,
membership: str,
sender: Optional[str] = None,
additional_content: Optional[dict] = None,
) -> EventBase: ) -> EventBase:
return make_event_from_dict( return make_event_from_dict(
{ {
@ -455,14 +527,14 @@ def _member_event(
"type": "m.room.member", "type": "m.room.member",
"sender": sender or user_id, "sender": sender or user_id,
"state_key": user_id, "state_key": user_id,
"content": {"membership": membership}, "content": {"membership": membership, **(additional_content or {})},
"prev_events": [], "prev_events": [],
} }
) )
def _join_event(user_id: str) -> EventBase: def _join_event(user_id: str, additional_content: Optional[dict] = None) -> EventBase:
return _member_event(user_id, "join") return _member_event(user_id, "join", additional_content=additional_content)
def _power_levels_event(sender: str, content: JsonDict) -> EventBase: def _power_levels_event(sender: str, content: JsonDict) -> EventBase: