forked from MirrorHub/synapse
Allow spam-checker modules to be provide async methods. (#8890)
Spam checker modules can now provide async methods. This is implemented in a backwards-compatible manner.
This commit is contained in:
parent
5d34f40d49
commit
f14428b25c
19 changed files with 98 additions and 73 deletions
1
changelog.d/8890.feature
Normal file
1
changelog.d/8890.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Spam-checkers may now define their methods as `async`.
|
|
@ -22,6 +22,8 @@ well as some specific methods:
|
|||
* `user_may_create_room`
|
||||
* `user_may_create_room_alias`
|
||||
* `user_may_publish_room`
|
||||
* `check_username_for_spam`
|
||||
* `check_registration_for_spam`
|
||||
|
||||
The details of the each of these methods (as well as their inputs and outputs)
|
||||
are documented in the `synapse.events.spamcheck.SpamChecker` class.
|
||||
|
@ -32,28 +34,33 @@ call back into the homeserver internals.
|
|||
### Example
|
||||
|
||||
```python
|
||||
from synapse.spam_checker_api import RegistrationBehaviour
|
||||
|
||||
class ExampleSpamChecker:
|
||||
def __init__(self, config, api):
|
||||
self.config = config
|
||||
self.api = api
|
||||
|
||||
def check_event_for_spam(self, foo):
|
||||
async def check_event_for_spam(self, foo):
|
||||
return False # allow all events
|
||||
|
||||
def user_may_invite(self, inviter_userid, invitee_userid, room_id):
|
||||
async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
|
||||
return True # allow all invites
|
||||
|
||||
def user_may_create_room(self, userid):
|
||||
async def user_may_create_room(self, userid):
|
||||
return True # allow all room creations
|
||||
|
||||
def user_may_create_room_alias(self, userid, room_alias):
|
||||
async def user_may_create_room_alias(self, userid, room_alias):
|
||||
return True # allow all room aliases
|
||||
|
||||
def user_may_publish_room(self, userid, room_id):
|
||||
async def user_may_publish_room(self, userid, room_id):
|
||||
return True # allow publishing of all rooms
|
||||
|
||||
def check_username_for_spam(self, user_profile):
|
||||
async def check_username_for_spam(self, user_profile):
|
||||
return False # allow all usernames
|
||||
|
||||
async def check_registration_for_spam(self, email_threepid, username, request_info):
|
||||
return RegistrationBehaviour.ALLOW # allow all registrations
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
|
|
@ -15,10 +15,11 @@
|
|||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from synapse.spam_checker_api import RegistrationBehaviour
|
||||
from synapse.types import Collection
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import synapse.events
|
||||
|
@ -39,7 +40,9 @@ class SpamChecker:
|
|||
else:
|
||||
self.spam_checkers.append(module(config=config))
|
||||
|
||||
def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
|
||||
async def check_event_for_spam(
|
||||
self, event: "synapse.events.EventBase"
|
||||
) -> Union[bool, str]:
|
||||
"""Checks if a given event is considered "spammy" by this server.
|
||||
|
||||
If the server considers an event spammy, then it will be rejected if
|
||||
|
@ -50,15 +53,16 @@ class SpamChecker:
|
|||
event: the event to be checked
|
||||
|
||||
Returns:
|
||||
True if the event is spammy.
|
||||
True or a string if the event is spammy. If a string is returned it
|
||||
will be used as the error message returned to the user.
|
||||
"""
|
||||
for spam_checker in self.spam_checkers:
|
||||
if spam_checker.check_event_for_spam(event):
|
||||
if await maybe_awaitable(spam_checker.check_event_for_spam(event)):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def user_may_invite(
|
||||
async def user_may_invite(
|
||||
self, inviter_userid: str, invitee_userid: str, room_id: str
|
||||
) -> bool:
|
||||
"""Checks if a given user may send an invite
|
||||
|
@ -75,14 +79,18 @@ class SpamChecker:
|
|||
"""
|
||||
for spam_checker in self.spam_checkers:
|
||||
if (
|
||||
spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
|
||||
await maybe_awaitable(
|
||||
spam_checker.user_may_invite(
|
||||
inviter_userid, invitee_userid, room_id
|
||||
)
|
||||
)
|
||||
is False
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def user_may_create_room(self, userid: str) -> bool:
|
||||
async def user_may_create_room(self, userid: str) -> bool:
|
||||
"""Checks if a given user may create a room
|
||||
|
||||
If this method returns false, the creation request will be rejected.
|
||||
|
@ -94,12 +102,15 @@ class SpamChecker:
|
|||
True if the user may create a room, otherwise False
|
||||
"""
|
||||
for spam_checker in self.spam_checkers:
|
||||
if spam_checker.user_may_create_room(userid) is False:
|
||||
if (
|
||||
await maybe_awaitable(spam_checker.user_may_create_room(userid))
|
||||
is False
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
|
||||
async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
|
||||
"""Checks if a given user may create a room alias
|
||||
|
||||
If this method returns false, the association request will be rejected.
|
||||
|
@ -112,12 +123,17 @@ class SpamChecker:
|
|||
True if the user may create a room alias, otherwise False
|
||||
"""
|
||||
for spam_checker in self.spam_checkers:
|
||||
if spam_checker.user_may_create_room_alias(userid, room_alias) is False:
|
||||
if (
|
||||
await maybe_awaitable(
|
||||
spam_checker.user_may_create_room_alias(userid, room_alias)
|
||||
)
|
||||
is False
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def user_may_publish_room(self, userid: str, room_id: str) -> bool:
|
||||
async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
|
||||
"""Checks if a given user may publish a room to the directory
|
||||
|
||||
If this method returns false, the publish request will be rejected.
|
||||
|
@ -130,12 +146,17 @@ class SpamChecker:
|
|||
True if the user may publish the room, otherwise False
|
||||
"""
|
||||
for spam_checker in self.spam_checkers:
|
||||
if spam_checker.user_may_publish_room(userid, room_id) is False:
|
||||
if (
|
||||
await maybe_awaitable(
|
||||
spam_checker.user_may_publish_room(userid, room_id)
|
||||
)
|
||||
is False
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
|
||||
async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
|
||||
"""Checks if a user ID or display name are considered "spammy" by this server.
|
||||
|
||||
If the server considers a username spammy, then it will not be included in
|
||||
|
@ -157,12 +178,12 @@ class SpamChecker:
|
|||
if checker:
|
||||
# Make a copy of the user profile object to ensure the spam checker
|
||||
# cannot modify it.
|
||||
if checker(user_profile.copy()):
|
||||
if await maybe_awaitable(checker(user_profile.copy())):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def check_registration_for_spam(
|
||||
async def check_registration_for_spam(
|
||||
self,
|
||||
email_threepid: Optional[dict],
|
||||
username: Optional[str],
|
||||
|
@ -185,7 +206,9 @@ class SpamChecker:
|
|||
# spam checker
|
||||
checker = getattr(spam_checker, "check_registration_for_spam", None)
|
||||
if checker:
|
||||
behaviour = checker(email_threepid, username, request_info)
|
||||
behaviour = await maybe_awaitable(
|
||||
checker(email_threepid, username, request_info)
|
||||
)
|
||||
assert isinstance(behaviour, RegistrationBehaviour)
|
||||
if behaviour != RegistrationBehaviour.ALLOW:
|
||||
return behaviour
|
||||
|
|
|
@ -78,6 +78,7 @@ class FederationBase:
|
|||
|
||||
ctx = current_context()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def callback(_, pdu: EventBase):
|
||||
with PreserveLoggingContext(ctx):
|
||||
if not check_event_content_hash(pdu):
|
||||
|
@ -105,7 +106,11 @@ class FederationBase:
|
|||
)
|
||||
return redacted_event
|
||||
|
||||
if self.spam_checker.check_event_for_spam(pdu):
|
||||
result = yield defer.ensureDeferred(
|
||||
self.spam_checker.check_event_for_spam(pdu)
|
||||
)
|
||||
|
||||
if result:
|
||||
logger.warning(
|
||||
"Event contains spam, redacting %s: %s",
|
||||
pdu.event_id,
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
import unicodedata
|
||||
|
@ -59,6 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
|||
from synapse.module_api import ModuleApi
|
||||
from synapse.types import JsonDict, Requester, UserID
|
||||
from synapse.util import stringutils as stringutils
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.threepids import canonicalise_email
|
||||
|
||||
|
@ -1639,6 +1639,6 @@ class PasswordProvider:
|
|||
|
||||
# This might return an awaitable, if it does block the log out
|
||||
# until it completes.
|
||||
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
await maybe_awaitable(
|
||||
g(user_id=user_id, device_id=device_id, access_token=access_token,)
|
||||
)
|
||||
|
|
|
@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler):
|
|||
403, "You must be in the room to create an alias for it"
|
||||
)
|
||||
|
||||
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
|
||||
if not await self.spam_checker.user_may_create_room_alias(
|
||||
user_id, room_alias
|
||||
):
|
||||
raise AuthError(403, "This user is not permitted to create this alias")
|
||||
|
||||
if not self.config.is_alias_creation_allowed(
|
||||
|
@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler):
|
|||
"""
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
if not self.spam_checker.user_may_publish_room(user_id, room_id):
|
||||
if not await self.spam_checker.user_may_publish_room(user_id, room_id):
|
||||
raise AuthError(
|
||||
403, "This user is not permitted to publish rooms to the room list"
|
||||
)
|
||||
|
|
|
@ -1593,7 +1593,7 @@ class FederationHandler(BaseHandler):
|
|||
if self.hs.config.block_non_admin_invites:
|
||||
raise SynapseError(403, "This server does not accept room invites")
|
||||
|
||||
if not self.spam_checker.user_may_invite(
|
||||
if not await self.spam_checker.user_may_invite(
|
||||
event.sender, event.state_key, event.room_id
|
||||
):
|
||||
raise SynapseError(
|
||||
|
|
|
@ -744,7 +744,7 @@ class EventCreationHandler:
|
|||
event.sender,
|
||||
)
|
||||
|
||||
spam_error = self.spam_checker.check_event_for_spam(event)
|
||||
spam_error = await self.spam_checker.check_event_for_spam(event)
|
||||
if spam_error:
|
||||
if not isinstance(spam_error, str):
|
||||
spam_error = "Spam is not permitted here"
|
||||
|
|
|
@ -18,7 +18,6 @@ from typing import List, Tuple
|
|||
from synapse.appservice import ApplicationService
|
||||
from synapse.handlers._base import BaseHandler
|
||||
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -98,10 +97,8 @@ class ReceiptsHandler(BaseHandler):
|
|||
|
||||
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
|
||||
# Note that the min here shouldn't be relied upon to be accurate.
|
||||
await maybe_awaitable(
|
||||
self.hs.get_pusherpool().on_new_receipts(
|
||||
min_batch_id, max_batch_id, affected_room_ids
|
||||
)
|
||||
await self.hs.get_pusherpool().on_new_receipts(
|
||||
min_batch_id, max_batch_id, affected_room_ids
|
||||
)
|
||||
|
||||
return True
|
||||
|
|
|
@ -187,7 +187,7 @@ class RegistrationHandler(BaseHandler):
|
|||
"""
|
||||
self.check_registration_ratelimit(address)
|
||||
|
||||
result = self.spam_checker.check_registration_for_spam(
|
||||
result = await self.spam_checker.check_registration_for_spam(
|
||||
threepid, localpart, user_agent_ips or [],
|
||||
)
|
||||
|
||||
|
|
|
@ -358,7 +358,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
"""
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
if not self.spam_checker.user_may_create_room(user_id):
|
||||
if not await self.spam_checker.user_may_create_room(user_id):
|
||||
raise SynapseError(403, "You are not permitted to create rooms")
|
||||
|
||||
creation_content = {
|
||||
|
@ -609,7 +609,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
403, "You are not permitted to create rooms", Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
if not is_requester_admin and not self.spam_checker.user_may_create_room(
|
||||
if not is_requester_admin and not await self.spam_checker.user_may_create_room(
|
||||
user_id
|
||||
):
|
||||
raise SynapseError(403, "You are not permitted to create rooms")
|
||||
|
|
|
@ -408,7 +408,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
)
|
||||
block_invite = True
|
||||
|
||||
if not self.spam_checker.user_may_invite(
|
||||
if not await self.spam_checker.user_may_invite(
|
||||
requester.user.to_string(), target.to_string(), room_id
|
||||
):
|
||||
logger.info("Blocking invite due to spam checker")
|
||||
|
|
|
@ -81,11 +81,11 @@ class UserDirectoryHandler(StateDeltasHandler):
|
|||
results = await self.store.search_user_dir(user_id, search_term, limit)
|
||||
|
||||
# Remove any spammy users from the results.
|
||||
results["results"] = [
|
||||
user
|
||||
for user in results["results"]
|
||||
if not self.spam_checker.check_username_for_spam(user)
|
||||
]
|
||||
non_spammy_users = []
|
||||
for user in results["results"]:
|
||||
if not await self.spam_checker.check_username_for_spam(user):
|
||||
non_spammy_users.append(user)
|
||||
results["results"] = non_spammy_users
|
||||
|
||||
return results
|
||||
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
from functools import wraps
|
||||
|
@ -25,6 +24,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.logging.context import LoggingContext, PreserveLoggingContext
|
||||
from synapse.logging.opentracing import noop_context_manager, start_active_span
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import resource
|
||||
|
@ -206,12 +206,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
|
|||
if bg_start_span:
|
||||
ctx = start_active_span(desc, tags={"request_id": context.request})
|
||||
with ctx:
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
|
||||
return result
|
||||
return await maybe_awaitable(func(*args, **kwargs))
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Background process '%s' threw an exception", desc,
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
|
@ -21,6 +20,7 @@ from typing import Optional
|
|||
|
||||
from synapse.config._base import Config
|
||||
from synapse.logging.context import defer_to_thread, run_in_background
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
|
||||
from ._base import FileInfo, Responder
|
||||
from .media_storage import FileResponder
|
||||
|
@ -91,16 +91,14 @@ class StorageProviderWrapper(StorageProvider):
|
|||
if self.store_synchronous:
|
||||
# store_file is supposed to return an Awaitable, but guard
|
||||
# against improper implementations.
|
||||
result = self.backend.store_file(path, file_info)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return await maybe_awaitable(self.backend.store_file(path, file_info))
|
||||
else:
|
||||
# TODO: Handle errors.
|
||||
async def store():
|
||||
try:
|
||||
result = self.backend.store_file(path, file_info)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return await maybe_awaitable(
|
||||
self.backend.store_file(path, file_info)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error storing file")
|
||||
|
||||
|
@ -110,9 +108,7 @@ class StorageProviderWrapper(StorageProvider):
|
|||
async def fetch(self, path, file_info):
|
||||
# store_file is supposed to return an Awaitable, but guard
|
||||
# against improper implementations.
|
||||
result = self.backend.fetch(path, file_info)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return await maybe_awaitable(self.backend.fetch(path, file_info))
|
||||
|
||||
|
||||
class FileStorageProviderBackend(StorageProvider):
|
||||
|
|
|
@ -618,7 +618,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
return StatsHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_spam_checker(self):
|
||||
def get_spam_checker(self) -> SpamChecker:
|
||||
return SpamChecker(self)
|
||||
|
||||
@cache_in_self
|
||||
|
|
|
@ -15,10 +15,12 @@
|
|||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import inspect
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Hashable,
|
||||
|
@ -542,11 +544,11 @@ class DoneAwaitable:
|
|||
raise StopIteration(self.value)
|
||||
|
||||
|
||||
def maybe_awaitable(value):
|
||||
def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
|
||||
"""Convert a value to an awaitable if not already an awaitable.
|
||||
"""
|
||||
|
||||
if hasattr(value, "__await__"):
|
||||
if inspect.isawaitable(value):
|
||||
assert isinstance(value, Awaitable)
|
||||
return value
|
||||
|
||||
return DoneAwaitable(value)
|
||||
|
|
|
@ -12,13 +12,13 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -105,10 +105,7 @@ class Signal:
|
|||
|
||||
async def do(observer):
|
||||
try:
|
||||
result = observer(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
return result
|
||||
return await maybe_awaitable(observer(*args, **kwargs))
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"%s signal observer %s failed: %r", self.name, observer, e,
|
||||
|
|
|
@ -270,7 +270,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
|||
spam_checker = self.hs.get_spam_checker()
|
||||
|
||||
class AllowAll:
|
||||
def check_username_for_spam(self, user_profile):
|
||||
async def check_username_for_spam(self, user_profile):
|
||||
# Allow all users.
|
||||
return False
|
||||
|
||||
|
@ -283,7 +283,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
# Configure a spam checker that filters all users.
|
||||
class BlockAll:
|
||||
def check_username_for_spam(self, user_profile):
|
||||
async def check_username_for_spam(self, user_profile):
|
||||
# All users are spammy.
|
||||
return True
|
||||
|
||||
|
|
Loading…
Reference in a new issue