Allow spam-checker modules to be provide async methods. (#8890)

Spam checker modules can now provide async methods. This is implemented
in a backwards-compatible manner.
This commit is contained in:
David Teller 2020-12-11 20:05:15 +01:00 committed by GitHub
parent 5d34f40d49
commit f14428b25c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 98 additions and 73 deletions

1
changelog.d/8890.feature Normal file
View file

@ -0,0 +1 @@
Spam-checkers may now define their methods as `async`.

View file

@ -22,6 +22,8 @@ well as some specific methods:
* `user_may_create_room`
* `user_may_create_room_alias`
* `user_may_publish_room`
* `check_username_for_spam`
* `check_registration_for_spam`
The details of the each of these methods (as well as their inputs and outputs)
are documented in the `synapse.events.spamcheck.SpamChecker` class.
@ -32,28 +34,33 @@ call back into the homeserver internals.
### Example
```python
from synapse.spam_checker_api import RegistrationBehaviour
class ExampleSpamChecker:
def __init__(self, config, api):
self.config = config
self.api = api
def check_event_for_spam(self, foo):
async def check_event_for_spam(self, foo):
return False # allow all events
def user_may_invite(self, inviter_userid, invitee_userid, room_id):
async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
return True # allow all invites
def user_may_create_room(self, userid):
async def user_may_create_room(self, userid):
return True # allow all room creations
def user_may_create_room_alias(self, userid, room_alias):
async def user_may_create_room_alias(self, userid, room_alias):
return True # allow all room aliases
def user_may_publish_room(self, userid, room_id):
async def user_may_publish_room(self, userid, room_id):
return True # allow publishing of all rooms
def check_username_for_spam(self, user_profile):
async def check_username_for_spam(self, user_profile):
return False # allow all usernames
async def check_registration_for_spam(self, email_threepid, username, request_info):
return RegistrationBehaviour.ALLOW # allow all registrations
```
## Configuration

View file

@ -15,10 +15,11 @@
# limitations under the License.
import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
import synapse.events
@ -39,7 +40,9 @@ class SpamChecker:
else:
self.spam_checkers.append(module(config=config))
def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
async def check_event_for_spam(
self, event: "synapse.events.EventBase"
) -> Union[bool, str]:
"""Checks if a given event is considered "spammy" by this server.
If the server considers an event spammy, then it will be rejected if
@ -50,15 +53,16 @@ class SpamChecker:
event: the event to be checked
Returns:
True if the event is spammy.
True or a string if the event is spammy. If a string is returned it
will be used as the error message returned to the user.
"""
for spam_checker in self.spam_checkers:
if spam_checker.check_event_for_spam(event):
if await maybe_awaitable(spam_checker.check_event_for_spam(event)):
return True
return False
def user_may_invite(
async def user_may_invite(
self, inviter_userid: str, invitee_userid: str, room_id: str
) -> bool:
"""Checks if a given user may send an invite
@ -75,14 +79,18 @@ class SpamChecker:
"""
for spam_checker in self.spam_checkers:
if (
spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
await maybe_awaitable(
spam_checker.user_may_invite(
inviter_userid, invitee_userid, room_id
)
)
is False
):
return False
return True
def user_may_create_room(self, userid: str) -> bool:
async def user_may_create_room(self, userid: str) -> bool:
"""Checks if a given user may create a room
If this method returns false, the creation request will be rejected.
@ -94,12 +102,15 @@ class SpamChecker:
True if the user may create a room, otherwise False
"""
for spam_checker in self.spam_checkers:
if spam_checker.user_may_create_room(userid) is False:
if (
await maybe_awaitable(spam_checker.user_may_create_room(userid))
is False
):
return False
return True
def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
"""Checks if a given user may create a room alias
If this method returns false, the association request will be rejected.
@ -112,12 +123,17 @@ class SpamChecker:
True if the user may create a room alias, otherwise False
"""
for spam_checker in self.spam_checkers:
if spam_checker.user_may_create_room_alias(userid, room_alias) is False:
if (
await maybe_awaitable(
spam_checker.user_may_create_room_alias(userid, room_alias)
)
is False
):
return False
return True
def user_may_publish_room(self, userid: str, room_id: str) -> bool:
async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
"""Checks if a given user may publish a room to the directory
If this method returns false, the publish request will be rejected.
@ -130,12 +146,17 @@ class SpamChecker:
True if the user may publish the room, otherwise False
"""
for spam_checker in self.spam_checkers:
if spam_checker.user_may_publish_room(userid, room_id) is False:
if (
await maybe_awaitable(
spam_checker.user_may_publish_room(userid, room_id)
)
is False
):
return False
return True
def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
If the server considers a username spammy, then it will not be included in
@ -157,12 +178,12 @@ class SpamChecker:
if checker:
# Make a copy of the user profile object to ensure the spam checker
# cannot modify it.
if checker(user_profile.copy()):
if await maybe_awaitable(checker(user_profile.copy())):
return True
return False
def check_registration_for_spam(
async def check_registration_for_spam(
self,
email_threepid: Optional[dict],
username: Optional[str],
@ -185,7 +206,9 @@ class SpamChecker:
# spam checker
checker = getattr(spam_checker, "check_registration_for_spam", None)
if checker:
behaviour = checker(email_threepid, username, request_info)
behaviour = await maybe_awaitable(
checker(email_threepid, username, request_info)
)
assert isinstance(behaviour, RegistrationBehaviour)
if behaviour != RegistrationBehaviour.ALLOW:
return behaviour

View file

@ -78,6 +78,7 @@ class FederationBase:
ctx = current_context()
@defer.inlineCallbacks
def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu):
@ -105,7 +106,11 @@ class FederationBase:
)
return redacted_event
if self.spam_checker.check_event_for_spam(pdu):
result = yield defer.ensureDeferred(
self.spam_checker.check_event_for_spam(pdu)
)
if result:
logger.warning(
"Event contains spam, redacting %s: %s",
pdu.event_id,

View file

@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import time
import unicodedata
@ -59,6 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
@ -1639,6 +1639,6 @@ class PasswordProvider:
# This might return an awaitable, if it does block the log out
# until it completes.
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
if inspect.isawaitable(result):
await result
await maybe_awaitable(
g(user_id=user_id, device_id=device_id, access_token=access_token,)
)

View file

@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler):
403, "You must be in the room to create an alias for it"
)
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
if not await self.spam_checker.user_may_create_room_alias(
user_id, room_alias
):
raise AuthError(403, "This user is not permitted to create this alias")
if not self.config.is_alias_creation_allowed(
@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler):
"""
user_id = requester.user.to_string()
if not self.spam_checker.user_may_publish_room(user_id, room_id):
if not await self.spam_checker.user_may_publish_room(user_id, room_id):
raise AuthError(
403, "This user is not permitted to publish rooms to the room list"
)

View file

@ -1593,7 +1593,7 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
if not self.spam_checker.user_may_invite(
if not await self.spam_checker.user_may_invite(
event.sender, event.state_key, event.room_id
):
raise SynapseError(

View file

@ -744,7 +744,7 @@ class EventCreationHandler:
event.sender,
)
spam_error = self.spam_checker.check_event_for_spam(event)
spam_error = await self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, str):
spam_error = "Spam is not permitted here"

View file

@ -18,7 +18,6 @@ from typing import List, Tuple
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@ -98,10 +97,8 @@ class ReceiptsHandler(BaseHandler):
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate.
await maybe_awaitable(
self.hs.get_pusherpool().on_new_receipts(
min_batch_id, max_batch_id, affected_room_ids
)
await self.hs.get_pusherpool().on_new_receipts(
min_batch_id, max_batch_id, affected_room_ids
)
return True

View file

@ -187,7 +187,7 @@ class RegistrationHandler(BaseHandler):
"""
self.check_registration_ratelimit(address)
result = self.spam_checker.check_registration_for_spam(
result = await self.spam_checker.check_registration_for_spam(
threepid, localpart, user_agent_ips or [],
)

View file

@ -358,7 +358,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
if not self.spam_checker.user_may_create_room(user_id):
if not await self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
@ -609,7 +609,7 @@ class RoomCreationHandler(BaseHandler):
403, "You are not permitted to create rooms", Codes.FORBIDDEN
)
if not is_requester_admin and not self.spam_checker.user_may_create_room(
if not is_requester_admin and not await self.spam_checker.user_may_create_room(
user_id
):
raise SynapseError(403, "You are not permitted to create rooms")

View file

@ -408,7 +408,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
block_invite = True
if not self.spam_checker.user_may_invite(
if not await self.spam_checker.user_may_invite(
requester.user.to_string(), target.to_string(), room_id
):
logger.info("Blocking invite due to spam checker")

View file

@ -81,11 +81,11 @@ class UserDirectoryHandler(StateDeltasHandler):
results = await self.store.search_user_dir(user_id, search_term, limit)
# Remove any spammy users from the results.
results["results"] = [
user
for user in results["results"]
if not self.spam_checker.check_username_for_spam(user)
]
non_spammy_users = []
for user in results["results"]:
if not await self.spam_checker.check_username_for_spam(user):
non_spammy_users.append(user)
results["results"] = non_spammy_users
return results

View file

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import threading
from functools import wraps
@ -25,6 +24,7 @@ from twisted.internet import defer
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.opentracing import noop_context_manager, start_active_span
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
import resource
@ -206,12 +206,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
if bg_start_span:
ctx = start_active_span(desc, tags={"request_id": context.request})
with ctx:
result = func(*args, **kwargs)
if inspect.isawaitable(result):
result = await result
return result
return await maybe_awaitable(func(*args, **kwargs))
except Exception:
logger.exception(
"Background process '%s' threw an exception", desc,

View file

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import os
import shutil
@ -21,6 +20,7 @@ from typing import Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
from synapse.util.async_helpers import maybe_awaitable
from ._base import FileInfo, Responder
from .media_storage import FileResponder
@ -91,16 +91,14 @@ class StorageProviderWrapper(StorageProvider):
if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
result = self.backend.store_file(path, file_info)
if inspect.isawaitable(result):
return await result
return await maybe_awaitable(self.backend.store_file(path, file_info))
else:
# TODO: Handle errors.
async def store():
try:
result = self.backend.store_file(path, file_info)
if inspect.isawaitable(result):
return await result
return await maybe_awaitable(
self.backend.store_file(path, file_info)
)
except Exception:
logger.exception("Error storing file")
@ -110,9 +108,7 @@ class StorageProviderWrapper(StorageProvider):
async def fetch(self, path, file_info):
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
result = self.backend.fetch(path, file_info)
if inspect.isawaitable(result):
return await result
return await maybe_awaitable(self.backend.fetch(path, file_info))
class FileStorageProviderBackend(StorageProvider):

View file

@ -618,7 +618,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return StatsHandler(self)
@cache_in_self
def get_spam_checker(self):
def get_spam_checker(self) -> SpamChecker:
return SpamChecker(self)
@cache_in_self

View file

@ -15,10 +15,12 @@
# limitations under the License.
import collections
import inspect
import logging
from contextlib import contextmanager
from typing import (
Any,
Awaitable,
Callable,
Dict,
Hashable,
@ -542,11 +544,11 @@ class DoneAwaitable:
raise StopIteration(self.value)
def maybe_awaitable(value):
def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
"""Convert a value to an awaitable if not already an awaitable.
"""
if hasattr(value, "__await__"):
if inspect.isawaitable(value):
assert isinstance(value, Awaitable)
return value
return DoneAwaitable(value)

View file

@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@ -105,10 +105,7 @@ class Signal:
async def do(observer):
try:
result = observer(*args, **kwargs)
if inspect.isawaitable(result):
result = await result
return result
return await maybe_awaitable(observer(*args, **kwargs))
except Exception as e:
logger.warning(
"%s signal observer %s failed: %r", self.name, observer, e,

View file

@ -270,7 +270,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
spam_checker = self.hs.get_spam_checker()
class AllowAll:
def check_username_for_spam(self, user_profile):
async def check_username_for_spam(self, user_profile):
# Allow all users.
return False
@ -283,7 +283,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that filters all users.
class BlockAll:
def check_username_for_spam(self, user_profile):
async def check_username_for_spam(self, user_profile):
# All users are spammy.
return True