0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-04-29 18:00:04 +02:00

Add spam checker option for restricted joins

This commit is contained in:
Tulir Asokan 2025-04-11 16:19:36 +03:00
parent 7f316b1d6f
commit 16a031d746
7 changed files with 86 additions and 10 deletions

View file

@ -83,6 +83,8 @@ class RestrictedJoinRuleTypes:
ROOM_MEMBERSHIP: Final = "m.room_membership"
MAU_SPAM_CHECKER: Final = "fi.mau.spam_checker"
class LoginType:
PASSWORD: Final = "m.login.password"

View file

@ -55,6 +55,7 @@ class EventAuthHandler:
self._state_storage_controller = hs.get_storage_controllers().state
self._server_name = hs.hostname
self._is_mine_id = hs.is_mine_id
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
async def check_auth_rules_from_context(
self,
@ -220,6 +221,7 @@ class EventAuthHandler:
room_version: RoomVersion,
user_id: str,
prev_membership: Optional[str],
room_id: str,
) -> None:
"""
Check whether a user can join a room without an invite due to restricted join rules.
@ -251,8 +253,27 @@ class EventAuthHandler:
# Get the rooms which allow access to this room and check if the user is
# in any of them.
allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
if not await self.is_user_in_rooms(allowed_rooms, user_id):
allowed_rooms, ask_spam_checker = await self.get_rooms_that_allow_join(
state_ids
)
if await self.is_user_in_rooms(allowed_rooms, user_id):
# If the user is in allowed rooms, allow the join without asking the spam checker
pass
elif ask_spam_checker:
spam_check = (
await self._spam_checker_module_callbacks.accept_make_join_callback(
user_id, room_id
)
)
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
raise AuthError(
403,
"The spam checker rejected your join",
errcode=spam_check[0],
additional_fields=spam_check[1],
)
else:
# If this is a remote request, the user might be in an allowed room
# that we do not know about.
if not self._is_mine_id(user_id):
@ -306,7 +327,7 @@ class EventAuthHandler:
async def get_rooms_that_allow_join(
self, state_ids: StateMap[str]
) -> StrCollection:
) -> tuple[StrCollection, bool]:
"""
Generate a list of rooms in which membership allows access to a room.
@ -319,7 +340,7 @@ class EventAuthHandler:
# If there's no join rule, then it defaults to invite (so this doesn't apply).
join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
if not join_rules_event_id:
return ()
return (), False
# If the join rule is not restricted, this doesn't apply.
join_rules_event = await self._store.get_event(join_rules_event_id)
@ -327,14 +348,19 @@ class EventAuthHandler:
# If allowed is of the wrong form, then only allow invited users.
allow_list = join_rules_event.content.get("allow", [])
if not isinstance(allow_list, list):
return ()
return (), False
# Pull out the other room IDs, invalid data gets filtered.
result = []
ask_spam_checker = False
for allow in allow_list:
if not isinstance(allow, dict):
continue
if allow.get("type") == RestrictedJoinRuleTypes.MAU_SPAM_CHECKER:
ask_spam_checker = True
continue
# If the type is unexpected, skip it.
if allow.get("type") != RestrictedJoinRuleTypes.ROOM_MEMBERSHIP:
continue
@ -345,7 +371,7 @@ class EventAuthHandler:
result.append(room_id)
return result
return result, ask_spam_checker
async def is_user_in_rooms(self, room_ids: StrCollection, user_id: str) -> bool:
"""

View file

@ -467,6 +467,7 @@ class FederationEventHandler:
event.room_version,
user_id,
prev_membership,
event.room_id,
)
@trace

View file

@ -1296,7 +1296,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Ensure the member should be allowed access via membership in a room.
await self.event_auth_handler.check_restricted_join_rules(
state_before_join, room_version, user_id, previous_membership
state_before_join, room_version, user_id, previous_membership, room_id
)
# If this is going to be a local join, additional information must

View file

@ -638,7 +638,7 @@ class RoomSummaryHandler:
if await self._event_auth_handler.has_restricted_join_rules(
state_ids, room_version
):
allowed_rooms = (
allowed_rooms, _ = (
await self._event_auth_handler.get_rooms_that_allow_join(state_ids)
)
if await self._event_auth_handler.is_user_in_rooms(
@ -660,7 +660,7 @@ class RoomSummaryHandler:
if await self._event_auth_handler.has_restricted_join_rules(
state_ids, room_version
):
allowed_rooms = (
allowed_rooms, _ = (
await self._event_auth_handler.get_rooms_that_allow_join(state_ids)
)
for space_id in allowed_rooms:
@ -771,7 +771,7 @@ class RoomSummaryHandler:
if await self._event_auth_handler.has_restricted_join_rules(
current_state_ids, room_version
):
allowed_rooms = (
allowed_rooms, _ = (
await self._event_auth_handler.get_rooms_that_allow_join(
current_state_ids
)

View file

@ -101,6 +101,7 @@ from synapse.module_api.callbacks.spamchecker_callbacks import (
USER_MAY_CREATE_ROOM_CALLBACK,
USER_MAY_INVITE_CALLBACK,
USER_MAY_JOIN_ROOM_CALLBACK,
ACCEPT_MAKE_JOIN_CALLBACK,
USER_MAY_PUBLISH_ROOM_CALLBACK,
USER_MAY_SEND_3PID_INVITE_CALLBACK,
SpamCheckerModuleApiCallbacks,
@ -304,6 +305,7 @@ class ModuleApi:
SHOULD_DROP_FEDERATED_EVENT_CALLBACK
] = None,
user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None,
accept_make_join: Optional[ACCEPT_MAKE_JOIN_CALLBACK] = None,
user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None,
user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None,
@ -326,6 +328,7 @@ class ModuleApi:
check_event_for_spam=check_event_for_spam,
should_drop_federated_event=should_drop_federated_event,
user_may_join_room=user_may_join_room,
accept_make_join=accept_make_join,
user_may_invite=user_may_invite,
user_may_send_3pid_invite=user_may_send_3pid_invite,
user_may_create_room=user_may_create_room,

View file

@ -88,6 +88,16 @@ USER_MAY_JOIN_ROOM_CALLBACK = Callable[
]
],
]
ACCEPT_MAKE_JOIN_CALLBACK = Callable[
[str, str],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
Tuple[Codes, JsonDict],
]
],
]
USER_MAY_INVITE_CALLBACK = Callable[
[str, str, str],
Awaitable[
@ -327,6 +337,7 @@ class SpamCheckerModuleApiCallbacks:
SHOULD_DROP_FEDERATED_EVENT_CALLBACK
] = []
self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
self._accept_make_join_callbacks: List[ACCEPT_MAKE_JOIN_CALLBACK] = []
self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
self._user_may_send_3pid_invite_callbacks: List[
USER_MAY_SEND_3PID_INVITE_CALLBACK
@ -354,6 +365,7 @@ class SpamCheckerModuleApiCallbacks:
SHOULD_DROP_FEDERATED_EVENT_CALLBACK
] = None,
user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None,
accept_make_join: Optional[ACCEPT_MAKE_JOIN_CALLBACK] = None,
user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None,
user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None,
@ -380,6 +392,9 @@ class SpamCheckerModuleApiCallbacks:
if user_may_join_room is not None:
self._user_may_join_room_callbacks.append(user_may_join_room)
if accept_make_join is not None:
self._accept_make_join_callbacks.append(accept_make_join)
if user_may_invite is not None:
self._user_may_invite_callbacks.append(user_may_invite)
@ -536,6 +551,35 @@ class SpamCheckerModuleApiCallbacks:
# No spam-checker has rejected the request, let it pass.
return self.NOT_SPAM
async def accept_make_join_callback(
self, user_id: str, room_id: str
) -> Union[Tuple[Codes, JsonDict], Literal["NOT_SPAM"]]:
for callback in self._accept_make_join_callbacks:
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(user_id, room_id))
# Normalize return values to `Codes` or `"NOT_SPAM"`.
if res is True or res is self.NOT_SPAM:
continue
elif res is False:
return synapse.api.errors.Codes.FORBIDDEN, {}
elif isinstance(res, synapse.api.errors.Codes):
return res, {}
elif (
isinstance(res, tuple)
and len(res) == 2
and isinstance(res[0], synapse.api.errors.Codes)
and isinstance(res[1], dict)
):
return res
else:
logger.warning(
"Module returned invalid value, rejecting join as spam"
)
return synapse.api.errors.Codes.FORBIDDEN, {}
# No spam-checker has rejected the request, let it pass.
return self.NOT_SPAM
async def user_may_invite(
self, inviter_userid: str, invitee_userid: str, room_id: str
) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]: