Implement cancellation support/protection for module callbacks (#12568)

There's no guarantee that module callbacks will handle cancellation
appropriately. Protect module callbacks with read semantics from
cancellation and avoid swallowing `CancelledError`s that arise.

Other module callbacks, such as the `on_*` callbacks, are presumed to
live on code paths that involve writes and aren't cancellation-friendly.
These module callbacks have been left alone.

Signed-off-by: Sean Quah <seanq@element.io>
This commit is contained in:
Sean Quah 2022-05-09 12:31:14 +01:00 committed by GitHub
parent 8de0facaae
commit a00462dd99
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 86 additions and 27 deletions

1
changelog.d/12568.misc Normal file
View file

@ -0,0 +1 @@
Protect module callbacks with read semantics against cancellation.

View file

@ -28,8 +28,10 @@ from typing import (
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from twisted.internet.defer import CancelledError
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.util.async_helpers import maybe_awaitable from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -158,7 +160,9 @@ class PresenceRouter:
try: try:
# Note: result is an object here, because we don't trust modules to # Note: result is an object here, because we don't trust modules to
# return the types they're supposed to. # return the types they're supposed to.
result: object = await callback(state_updates) result: object = await delay_cancellation(callback(state_updates))
except CancelledError:
raise
except Exception as e: except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e) logger.warning("Failed to run module API callback %s: %s", callback, e)
continue continue
@ -210,7 +214,9 @@ class PresenceRouter:
# run all the callbacks for get_interested_users and combine the results # run all the callbacks for get_interested_users and combine the results
for callback in self._get_interested_users_callbacks: for callback in self._get_interested_users_callbacks:
try: try:
result = await callback(user_id) result = await delay_cancellation(callback(user_id))
except CancelledError:
raise
except Exception as e: except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e) logger.warning("Failed to run module API callback %s: %s", callback, e)
continue continue

View file

@ -31,7 +31,7 @@ from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import ReadableFileWrapper from synapse.rest.media.v1.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserProfile from synapse.types import RoomAlias, UserProfile
from synapse.util.async_helpers import maybe_awaitable from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
if TYPE_CHECKING: if TYPE_CHECKING:
import synapse.events import synapse.events
@ -255,7 +255,7 @@ class SpamChecker:
will be used as the error message returned to the user. will be used as the error message returned to the user.
""" """
for callback in self._check_event_for_spam_callbacks: for callback in self._check_event_for_spam_callbacks:
res: Union[bool, str] = await callback(event) res: Union[bool, str] = await delay_cancellation(callback(event))
if res: if res:
return res return res
@ -276,7 +276,10 @@ class SpamChecker:
Whether the user may join the room Whether the user may join the room
""" """
for callback in self._user_may_join_room_callbacks: for callback in self._user_may_join_room_callbacks:
if await callback(user_id, room_id, is_invited) is False: may_join_room = await delay_cancellation(
callback(user_id, room_id, is_invited)
)
if may_join_room is False:
return False return False
return True return True
@ -297,7 +300,10 @@ class SpamChecker:
True if the user may send an invite, otherwise False True if the user may send an invite, otherwise False
""" """
for callback in self._user_may_invite_callbacks: for callback in self._user_may_invite_callbacks:
if await callback(inviter_userid, invitee_userid, room_id) is False: may_invite = await delay_cancellation(
callback(inviter_userid, invitee_userid, room_id)
)
if may_invite is False:
return False return False
return True return True
@ -322,7 +328,10 @@ class SpamChecker:
True if the user may send the invite, otherwise False True if the user may send the invite, otherwise False
""" """
for callback in self._user_may_send_3pid_invite_callbacks: for callback in self._user_may_send_3pid_invite_callbacks:
if await callback(inviter_userid, medium, address, room_id) is False: may_send_3pid_invite = await delay_cancellation(
callback(inviter_userid, medium, address, room_id)
)
if may_send_3pid_invite is False:
return False return False
return True return True
@ -339,7 +348,8 @@ class SpamChecker:
True if the user may create a room, otherwise False True if the user may create a room, otherwise False
""" """
for callback in self._user_may_create_room_callbacks: for callback in self._user_may_create_room_callbacks:
if await callback(userid) is False: may_create_room = await delay_cancellation(callback(userid))
if may_create_room is False:
return False return False
return True return True
@ -359,7 +369,10 @@ class SpamChecker:
True if the user may create a room alias, otherwise False True if the user may create a room alias, otherwise False
""" """
for callback in self._user_may_create_room_alias_callbacks: for callback in self._user_may_create_room_alias_callbacks:
if await callback(userid, room_alias) is False: may_create_room_alias = await delay_cancellation(
callback(userid, room_alias)
)
if may_create_room_alias is False:
return False return False
return True return True
@ -377,7 +390,8 @@ class SpamChecker:
True if the user may publish the room, otherwise False True if the user may publish the room, otherwise False
""" """
for callback in self._user_may_publish_room_callbacks: for callback in self._user_may_publish_room_callbacks:
if await callback(userid, room_id) is False: may_publish_room = await delay_cancellation(callback(userid, room_id))
if may_publish_room is False:
return False return False
return True return True
@ -400,7 +414,7 @@ class SpamChecker:
for callback in self._check_username_for_spam_callbacks: for callback in self._check_username_for_spam_callbacks:
# Make a copy of the user profile object to ensure the spam checker cannot # Make a copy of the user profile object to ensure the spam checker cannot
# modify it. # modify it.
if await callback(user_profile.copy()): if await delay_cancellation(callback(user_profile.copy())):
return True return True
return False return False
@ -428,7 +442,7 @@ class SpamChecker:
""" """
for callback in self._check_registration_for_spam_callbacks: for callback in self._check_registration_for_spam_callbacks:
behaviour = await ( behaviour = await delay_cancellation(
callback(email_threepid, username, request_info, auth_provider_id) callback(email_threepid, username, request_info, auth_provider_id)
) )
assert isinstance(behaviour, RegistrationBehaviour) assert isinstance(behaviour, RegistrationBehaviour)
@ -472,7 +486,7 @@ class SpamChecker:
""" """
for callback in self._check_media_file_for_spam_callbacks: for callback in self._check_media_file_for_spam_callbacks:
spam = await callback(file_wrapper, file_info) spam = await delay_cancellation(callback(file_wrapper, file_info))
if spam: if spam:
return True return True

View file

@ -14,12 +14,14 @@
import logging import logging
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple
from twisted.internet.defer import CancelledError
from synapse.api.errors import ModuleFailedException, SynapseError from synapse.api.errors import ModuleFailedException, SynapseError
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.types import Requester, StateMap from synapse.types import Requester, StateMap
from synapse.util.async_helpers import maybe_awaitable from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -263,7 +265,11 @@ class ThirdPartyEventRules:
for callback in self._check_event_allowed_callbacks: for callback in self._check_event_allowed_callbacks:
try: try:
res, replacement_data = await callback(event, state_events) res, replacement_data = await delay_cancellation(
callback(event, state_events)
)
except CancelledError:
raise
except SynapseError as e: except SynapseError as e:
# FIXME: Being able to throw SynapseErrors is relied upon by # FIXME: Being able to throw SynapseErrors is relied upon by
# some modules. PR #10386 accidentally broke this ability. # some modules. PR #10386 accidentally broke this ability.
@ -333,8 +339,13 @@ class ThirdPartyEventRules:
for callback in self._check_threepid_can_be_invited_callbacks: for callback in self._check_threepid_can_be_invited_callbacks:
try: try:
if await callback(medium, address, state_events) is False: threepid_can_be_invited = await delay_cancellation(
callback(medium, address, state_events)
)
if threepid_can_be_invited is False:
return False return False
except CancelledError:
raise
except Exception as e: except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e) logger.warning("Failed to run module API callback %s: %s", callback, e)
@ -361,8 +372,13 @@ class ThirdPartyEventRules:
for callback in self._check_visibility_can_be_modified_callbacks: for callback in self._check_visibility_can_be_modified_callbacks:
try: try:
if await callback(room_id, state_events, new_visibility) is False: visibility_can_be_modified = await delay_cancellation(
callback(room_id, state_events, new_visibility)
)
if visibility_can_be_modified is False:
return False return False
except CancelledError:
raise
except Exception as e: except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e) logger.warning("Failed to run module API callback %s: %s", callback, e)
@ -400,8 +416,11 @@ class ThirdPartyEventRules:
""" """
for callback in self._check_can_shutdown_room_callbacks: for callback in self._check_can_shutdown_room_callbacks:
try: try:
if await callback(user_id, room_id) is False: can_shutdown_room = await delay_cancellation(callback(user_id, room_id))
if can_shutdown_room is False:
return False return False
except CancelledError:
raise
except Exception as e: except Exception as e:
logger.exception( logger.exception(
"Failed to run module API callback %s: %s", callback, e "Failed to run module API callback %s: %s", callback, e
@ -422,8 +441,13 @@ class ThirdPartyEventRules:
""" """
for callback in self._check_can_deactivate_user_callbacks: for callback in self._check_can_deactivate_user_callbacks:
try: try:
if await callback(user_id, by_admin) is False: can_deactivate_user = await delay_cancellation(
callback(user_id, by_admin)
)
if can_deactivate_user is False:
return False return False
except CancelledError:
raise
except Exception as e: except Exception as e:
logger.exception( logger.exception(
"Failed to run module API callback %s: %s", callback, e "Failed to run module API callback %s: %s", callback, e

View file

@ -23,6 +23,7 @@ from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID from synapse.types import UserID
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async_helpers import delay_cancellation
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -150,7 +151,7 @@ class AccountValidityHandler:
Whether the user has expired. Whether the user has expired.
""" """
for callback in self._is_user_expired_callbacks: for callback in self._is_user_expired_callbacks:
expired = await callback(user_id) expired = await delay_cancellation(callback(user_id))
if expired is not None: if expired is not None:
return expired return expired

View file

@ -41,6 +41,7 @@ import pymacaroons
import unpaddedbase64 import unpaddedbase64
from pymacaroons.exceptions import MacaroonVerificationFailedException from pymacaroons.exceptions import MacaroonVerificationFailedException
from twisted.internet.defer import CancelledError
from twisted.web.server import Request from twisted.web.server import Request
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
@ -67,7 +68,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict, Requester, UserID from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import maybe_awaitable from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import base62_encode from synapse.util.stringutils import base62_encode
@ -2202,7 +2203,11 @@ class PasswordAuthProvider:
# other than None (i.e. until a callback returns a success) # other than None (i.e. until a callback returns a success)
for callback in self.auth_checker_callbacks[login_type]: for callback in self.auth_checker_callbacks[login_type]:
try: try:
result = await callback(username, login_type, login_dict) result = await delay_cancellation(
callback(username, login_type, login_dict)
)
except CancelledError:
raise
except Exception as e: except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e) logger.warning("Failed to run module API callback %s: %s", callback, e)
continue continue
@ -2263,7 +2268,9 @@ class PasswordAuthProvider:
for callback in self.check_3pid_auth_callbacks: for callback in self.check_3pid_auth_callbacks:
try: try:
result = await callback(medium, address, password) result = await delay_cancellation(callback(medium, address, password))
except CancelledError:
raise
except Exception as e: except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e) logger.warning("Failed to run module API callback %s: %s", callback, e)
continue continue
@ -2345,7 +2352,7 @@ class PasswordAuthProvider:
""" """
for callback in self.get_username_for_registration_callbacks: for callback in self.get_username_for_registration_callbacks:
try: try:
res = await callback(uia_results, params) res = await delay_cancellation(callback(uia_results, params))
if isinstance(res, str): if isinstance(res, str):
return res return res
@ -2359,6 +2366,8 @@ class PasswordAuthProvider:
callback, callback,
res, res,
) )
except CancelledError:
raise
except Exception as e: except Exception as e:
logger.error( logger.error(
"Module raised an exception in get_username_for_registration: %s", "Module raised an exception in get_username_for_registration: %s",
@ -2388,7 +2397,7 @@ class PasswordAuthProvider:
""" """
for callback in self.get_displayname_for_registration_callbacks: for callback in self.get_displayname_for_registration_callbacks:
try: try:
res = await callback(uia_results, params) res = await delay_cancellation(callback(uia_results, params))
if isinstance(res, str): if isinstance(res, str):
return res return res
@ -2402,6 +2411,8 @@ class PasswordAuthProvider:
callback, callback,
res, res,
) )
except CancelledError:
raise
except Exception as e: except Exception as e:
logger.error( logger.error(
"Module raised an exception in get_displayname_for_registration: %s", "Module raised an exception in get_displayname_for_registration: %s",
@ -2429,7 +2440,7 @@ class PasswordAuthProvider:
""" """
for callback in self.is_3pid_allowed_callbacks: for callback in self.is_3pid_allowed_callbacks:
try: try:
res = await callback(medium, address, registration) res = await delay_cancellation(callback(medium, address, registration))
if res is False: if res is False:
return res return res
@ -2443,6 +2454,8 @@ class PasswordAuthProvider:
callback, callback,
res, res,
) )
except CancelledError:
raise
except Exception as e: except Exception as e:
logger.error("Module raised an exception in is_3pid_allowed: %s", e) logger.error("Module raised an exception in is_3pid_allowed: %s", e)
raise SynapseError(code=500, msg="Internal Server Error") raise SynapseError(code=500, msg="Internal Server Error")