Merge branch 'erikj/media_spam_checker' into matrix-org-hotfixes
This commit is contained in:
commit
25757a3d47
1
changelog.d/9297.feature
Normal file
1
changelog.d/9297.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Further improvements to the user experience of registration via single sign-on.
|
1
changelog.d/9302.bugfix
Normal file
1
changelog.d/9302.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix new ratelimiting for invites to respect the `ratelimit` flag on application services. Introduced in v1.27.0rc1.
|
1
changelog.d/9310.doc
Normal file
1
changelog.d/9310.doc
Normal file
|
@ -0,0 +1 @@
|
|||
Clarify the sample configuration for changes made to the template loading code.
|
1
changelog.d/9311.feature
Normal file
1
changelog.d/9311.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add hook to spam checker modules that allow checking file uploads and remote downloads.
|
|
@ -1961,8 +1961,7 @@ sso:
|
|||
#
|
||||
# When rendering, this template is given the following variables:
|
||||
# * redirect_url: the URL that the user will be redirected to after
|
||||
# login. Needs manual escaping (see
|
||||
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
|
||||
# login.
|
||||
#
|
||||
# * server_name: the homeserver's name.
|
||||
#
|
||||
|
@ -2040,15 +2039,12 @@ sso:
|
|||
#
|
||||
# When rendering, this template is given the following variables:
|
||||
#
|
||||
# * redirect_url: the URL the user is about to be redirected to. Needs
|
||||
# manual escaping (see
|
||||
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
|
||||
# * redirect_url: the URL the user is about to be redirected to.
|
||||
#
|
||||
# * display_url: the same as `redirect_url`, but with the query
|
||||
# parameters stripped. The intention is to have a
|
||||
# human-readable URL to show to users, not to use it as
|
||||
# the final address to redirect to. Needs manual escaping
|
||||
# (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
|
||||
# the final address to redirect to.
|
||||
#
|
||||
# * server_name: the homeserver's name.
|
||||
#
|
||||
|
@ -2068,9 +2064,7 @@ sso:
|
|||
# process: 'sso_auth_confirm.html'.
|
||||
#
|
||||
# When rendering, this template is given the following variables:
|
||||
# * redirect_url: the URL the user is about to be redirected to. Needs
|
||||
# manual escaping (see
|
||||
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
|
||||
# * redirect_url: the URL the user is about to be redirected to.
|
||||
#
|
||||
# * description: the operation which the user is being asked to confirm
|
||||
#
|
||||
|
|
|
@ -61,6 +61,9 @@ class ExampleSpamChecker:
|
|||
|
||||
async def check_registration_for_spam(self, email_threepid, username, request_info):
|
||||
return RegistrationBehaviour.ALLOW # allow all registrations
|
||||
|
||||
async def check_media_file_for_spam(self, file_wrapper, file_info):
|
||||
return False # allow all media
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
|
|
@ -106,8 +106,7 @@ class SSOConfig(Config):
|
|||
#
|
||||
# When rendering, this template is given the following variables:
|
||||
# * redirect_url: the URL that the user will be redirected to after
|
||||
# login. Needs manual escaping (see
|
||||
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
|
||||
# login.
|
||||
#
|
||||
# * server_name: the homeserver's name.
|
||||
#
|
||||
|
@ -185,15 +184,12 @@ class SSOConfig(Config):
|
|||
#
|
||||
# When rendering, this template is given the following variables:
|
||||
#
|
||||
# * redirect_url: the URL the user is about to be redirected to. Needs
|
||||
# manual escaping (see
|
||||
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
|
||||
# * redirect_url: the URL the user is about to be redirected to.
|
||||
#
|
||||
# * display_url: the same as `redirect_url`, but with the query
|
||||
# parameters stripped. The intention is to have a
|
||||
# human-readable URL to show to users, not to use it as
|
||||
# the final address to redirect to. Needs manual escaping
|
||||
# (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
|
||||
# the final address to redirect to.
|
||||
#
|
||||
# * server_name: the homeserver's name.
|
||||
#
|
||||
|
@ -213,9 +209,7 @@ class SSOConfig(Config):
|
|||
# process: 'sso_auth_confirm.html'.
|
||||
#
|
||||
# When rendering, this template is given the following variables:
|
||||
# * redirect_url: the URL the user is about to be redirected to. Needs
|
||||
# manual escaping (see
|
||||
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
|
||||
# * redirect_url: the URL the user is about to be redirected to.
|
||||
#
|
||||
# * description: the operation which the user is being asked to confirm
|
||||
#
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from synapse.rest.media.v1._base import FileInfo
|
||||
from synapse.rest.media.v1.media_storage import ReadableFileWrapper
|
||||
from synapse.spam_checker_api import RegistrationBehaviour
|
||||
from synapse.types import Collection
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
|
@ -214,3 +216,48 @@ class SpamChecker:
|
|||
return behaviour
|
||||
|
||||
return RegistrationBehaviour.ALLOW
|
||||
|
||||
async def check_media_file_for_spam(
|
||||
self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
|
||||
) -> bool:
|
||||
"""Checks if a piece of newly uploaded media should be blocked.
|
||||
|
||||
This will be called for local uploads, downloads of remote media, each
|
||||
thumbnail generated for those, and web pages/images used for URL
|
||||
previews.
|
||||
|
||||
Note that care should be taken to not do blocking IO operations in the
|
||||
main thread. For example, to get the contents of a file a module
|
||||
should do::
|
||||
|
||||
async def check_media_file_for_spam(
|
||||
self, file: ReadableFileWrapper, file_info: FileInfo
|
||||
) -> bool:
|
||||
buffer = BytesIO()
|
||||
await file.write_chunks_to(buffer.write)
|
||||
|
||||
if buffer.getvalue() == b"Hello World":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
Args:
|
||||
file: An object that allows reading the contents of the media.
|
||||
file_info: Metadata about the file.
|
||||
|
||||
Returns:
|
||||
True if the media should be blocked or False if it should be
|
||||
allowed.
|
||||
"""
|
||||
|
||||
for spam_checker in self.spam_checkers:
|
||||
# For backwards compatibility, only run if the method exists on the
|
||||
# spam checker
|
||||
checker = getattr(spam_checker, "check_media_file_for_spam", None)
|
||||
if checker:
|
||||
spam = await maybe_awaitable(checker(file_wrapper, file_info))
|
||||
if spam:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
|
|
@ -1619,7 +1619,9 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
# We retrieve the room member handler here as to not cause a cyclic dependency
|
||||
member_handler = self.hs.get_room_member_handler()
|
||||
member_handler.ratelimit_invite(event.room_id, event.state_key)
|
||||
# We don't rate limit based on room ID, as that should be done by
|
||||
# sending server.
|
||||
member_handler.ratelimit_invite(None, event.state_key)
|
||||
|
||||
# keep a record of the room version, if we don't yet know it.
|
||||
# (this may get overwritten if we later get a different room version in a
|
||||
|
|
|
@ -156,10 +156,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def ratelimit_invite(self, room_id: str, invitee_user_id: str):
|
||||
def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str):
|
||||
"""Ratelimit invites by room and by target user.
|
||||
|
||||
If room ID is missing then we just rate limit by target user.
|
||||
"""
|
||||
self._invites_per_room_limiter.ratelimit(room_id)
|
||||
if room_id:
|
||||
self._invites_per_room_limiter.ratelimit(room_id)
|
||||
|
||||
self._invites_per_user_limiter.ratelimit(invitee_user_id)
|
||||
|
||||
async def _local_membership_update(
|
||||
|
@ -426,7 +430,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
if effective_membership_state == Membership.INVITE:
|
||||
target_id = target.to_string()
|
||||
if ratelimit:
|
||||
self.ratelimit_invite(room_id, target_id)
|
||||
# Don't ratelimit application services.
|
||||
if not requester.app_service or requester.app_service.is_rate_limited():
|
||||
self.ratelimit_invite(room_id, target_id)
|
||||
|
||||
# block any attempts to invite the server notices mxid
|
||||
if target_id == self._server_notices_mxid:
|
||||
|
|
|
@ -18,6 +18,19 @@
|
|||
font-size: 12px;
|
||||
}
|
||||
|
||||
.username_input.invalid {
|
||||
border-color: #FE2928;
|
||||
}
|
||||
|
||||
.username_input.invalid input, .username_input.invalid label {
|
||||
color: #FE2928;
|
||||
}
|
||||
|
||||
.username_input div, .username_input input {
|
||||
line-height: 18px;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.username_input label {
|
||||
position: absolute;
|
||||
top: -8px;
|
||||
|
@ -78,6 +91,15 @@
|
|||
display: block;
|
||||
margin-top: 8px;
|
||||
}
|
||||
|
||||
output {
|
||||
padding: 0 14px;
|
||||
display: block;
|
||||
}
|
||||
|
||||
output.error {
|
||||
color: #FE2928;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
|
@ -87,12 +109,13 @@
|
|||
</header>
|
||||
<main>
|
||||
<form method="post" class="form__input" id="form">
|
||||
<div class="username_input">
|
||||
<div class="username_input" id="username_input">
|
||||
<label for="field-username">Username</label>
|
||||
<div class="prefix">@</div>
|
||||
<input type="text" name="username" id="field-username" autofocus required pattern="[a-z0-9\-=_\/\.]+">
|
||||
<input type="text" name="username" id="field-username" autofocus>
|
||||
<div class="postfix">:{{ server_name }}</div>
|
||||
</div>
|
||||
<output for="username_input" id="field-username-output"></output>
|
||||
<input type="submit" value="Continue" class="primary-button">
|
||||
{% if user_attributes %}
|
||||
<section class="idp-pick-details">
|
||||
|
|
|
@ -1,14 +1,24 @@
|
|||
const usernameField = document.getElementById("field-username");
|
||||
const usernameOutput = document.getElementById("field-username-output");
|
||||
const form = document.getElementById("form");
|
||||
|
||||
// needed to validate on change event when no input was changed
|
||||
let needsValidation = true;
|
||||
let isValid = false;
|
||||
|
||||
function throttle(fn, wait) {
|
||||
let timeout;
|
||||
return function() {
|
||||
const throttleFn = function() {
|
||||
const args = Array.from(arguments);
|
||||
if (timeout) {
|
||||
clearTimeout(timeout);
|
||||
}
|
||||
timeout = setTimeout(fn.bind.apply(fn, [null].concat(args)), wait);
|
||||
}
|
||||
};
|
||||
throttleFn.cancelQueued = function() {
|
||||
clearTimeout(timeout);
|
||||
};
|
||||
return throttleFn;
|
||||
}
|
||||
|
||||
function checkUsernameAvailable(username) {
|
||||
|
@ -16,14 +26,14 @@ function checkUsernameAvailable(username) {
|
|||
return fetch(check_uri, {
|
||||
// include the cookie
|
||||
"credentials": "same-origin",
|
||||
}).then((response) => {
|
||||
}).then(function(response) {
|
||||
if(!response.ok) {
|
||||
// for non-200 responses, raise the body of the response as an exception
|
||||
return response.text().then((text) => { throw new Error(text); });
|
||||
} else {
|
||||
return response.json();
|
||||
}
|
||||
}).then((json) => {
|
||||
}).then(function(json) {
|
||||
if(json.error) {
|
||||
return {message: json.error};
|
||||
} else if(json.available) {
|
||||
|
@ -34,33 +44,49 @@ function checkUsernameAvailable(username) {
|
|||
});
|
||||
}
|
||||
|
||||
const allowedUsernameCharacters = new RegExp("^[a-z0-9\\.\\_\\-\\/\\=]+$");
|
||||
const allowedCharactersString = "lowercase letters, digits, ., _, -, /, =";
|
||||
|
||||
function reportError(error) {
|
||||
throttledCheckUsernameAvailable.cancelQueued();
|
||||
usernameOutput.innerText = error;
|
||||
usernameOutput.classList.add("error");
|
||||
usernameField.parentElement.classList.add("invalid");
|
||||
usernameField.focus();
|
||||
}
|
||||
|
||||
function validateUsername(username) {
|
||||
usernameField.setCustomValidity("");
|
||||
if (usernameField.validity.valueMissing) {
|
||||
usernameField.setCustomValidity("Please provide a username");
|
||||
return;
|
||||
isValid = false;
|
||||
needsValidation = false;
|
||||
usernameOutput.innerText = "";
|
||||
usernameField.parentElement.classList.remove("invalid");
|
||||
usernameOutput.classList.remove("error");
|
||||
if (!username) {
|
||||
return reportError("Please provide a username");
|
||||
}
|
||||
if (usernameField.validity.patternMismatch) {
|
||||
usernameField.setCustomValidity("Invalid username, please only use " + allowedCharactersString);
|
||||
return;
|
||||
if (username.length > 255) {
|
||||
return reportError("Too long, please choose something shorter");
|
||||
}
|
||||
usernameField.setCustomValidity("Checking if username is available …");
|
||||
if (!allowedUsernameCharacters.test(username)) {
|
||||
return reportError("Invalid username, please only use " + allowedCharactersString);
|
||||
}
|
||||
usernameOutput.innerText = "Checking if username is available …";
|
||||
throttledCheckUsernameAvailable(username);
|
||||
}
|
||||
|
||||
const throttledCheckUsernameAvailable = throttle(function(username) {
|
||||
const handleError = function(err) {
|
||||
const handleError = function(err) {
|
||||
// don't prevent form submission on error
|
||||
usernameField.setCustomValidity("");
|
||||
console.log(err.message);
|
||||
usernameOutput.innerText = "";
|
||||
isValid = true;
|
||||
};
|
||||
try {
|
||||
checkUsernameAvailable(username).then(function(result) {
|
||||
if (!result.available) {
|
||||
usernameField.setCustomValidity(result.message);
|
||||
usernameField.reportValidity();
|
||||
reportError(result.message);
|
||||
} else {
|
||||
usernameField.setCustomValidity("");
|
||||
isValid = true;
|
||||
usernameOutput.innerText = "";
|
||||
}
|
||||
}, handleError);
|
||||
} catch (err) {
|
||||
|
@ -68,9 +94,23 @@ const throttledCheckUsernameAvailable = throttle(function(username) {
|
|||
}
|
||||
}, 500);
|
||||
|
||||
form.addEventListener("submit", function(evt) {
|
||||
if (needsValidation) {
|
||||
validateUsername(usernameField.value);
|
||||
evt.preventDefault();
|
||||
return;
|
||||
}
|
||||
if (!isValid) {
|
||||
evt.preventDefault();
|
||||
usernameField.focus();
|
||||
return;
|
||||
}
|
||||
});
|
||||
usernameField.addEventListener("input", function(evt) {
|
||||
validateUsername(usernameField.value);
|
||||
});
|
||||
usernameField.addEventListener("change", function(evt) {
|
||||
validateUsername(usernameField.value);
|
||||
if (needsValidation) {
|
||||
validateUsername(usernameField.value);
|
||||
}
|
||||
});
|
||||
|
|
|
@ -16,13 +16,17 @@ import contextlib
|
|||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
|
||||
from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.interfaces import IConsumer
|
||||
from twisted.protocols.basic import FileSender
|
||||
|
||||
from synapse.api.errors import NotFoundError
|
||||
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
|
||||
from synapse.util import Clock
|
||||
from synapse.util.file_consumer import BackgroundFileConsumer
|
||||
|
||||
from ._base import FileInfo, Responder
|
||||
|
@ -58,6 +62,8 @@ class MediaStorage:
|
|||
self.local_media_directory = local_media_directory
|
||||
self.filepaths = filepaths
|
||||
self.storage_providers = storage_providers
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
async def store_file(self, source: IO, file_info: FileInfo) -> str:
|
||||
"""Write `source` to the on disk media store, and also any other
|
||||
|
@ -127,18 +133,29 @@ class MediaStorage:
|
|||
f.flush()
|
||||
f.close()
|
||||
|
||||
spam = await self.spam_checker.check_media_file_for_spam(
|
||||
ReadableFileWrapper(self.clock, fname), file_info
|
||||
)
|
||||
if spam:
|
||||
logger.info("Blocking media due to spam checker")
|
||||
# Note that we'll delete the stored media, due to the
|
||||
# try/except below. The media also won't be stored in
|
||||
# the DB.
|
||||
raise SpamMediaException()
|
||||
|
||||
for provider in self.storage_providers:
|
||||
await provider.store_file(path, file_info)
|
||||
|
||||
finished_called[0] = True
|
||||
|
||||
yield f, fname, finish
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
try:
|
||||
os.remove(fname)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
|
||||
raise e from None
|
||||
|
||||
if not finished_called:
|
||||
raise Exception("Finished callback not called")
|
||||
|
@ -302,3 +319,39 @@ class FileResponder(Responder):
|
|||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.open_file.close()
|
||||
|
||||
|
||||
class SpamMediaException(NotFoundError):
|
||||
"""The media was blocked by a spam checker, so we simply 404 the request (in
|
||||
the same way as if it was quarantined).
|
||||
"""
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class ReadableFileWrapper:
|
||||
"""Wrapper that allows reading a file in chunks, yielding to the reactor,
|
||||
and writing to a callback.
|
||||
|
||||
This is simplified `FileSender` that takes an IO object rather than an
|
||||
`IConsumer`.
|
||||
"""
|
||||
|
||||
CHUNK_SIZE = 2 ** 14
|
||||
|
||||
clock = attr.ib(type=Clock)
|
||||
path = attr.ib(type=str)
|
||||
|
||||
async def write_chunks_to(self, callback: Callable[[bytes], None]):
|
||||
"""Reads the file in chunks and calls the callback with each chunk.
|
||||
"""
|
||||
|
||||
with open(self.path, "rb") as file:
|
||||
while True:
|
||||
chunk = file.read(self.CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
callback(chunk)
|
||||
|
||||
# We yield to the reactor by sleeping for 0 seconds.
|
||||
await self.clock.sleep(0)
|
||||
|
|
|
@ -22,6 +22,7 @@ from twisted.web.http import Request
|
|||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
from synapse.http.servlet import parse_string
|
||||
from synapse.rest.media.v1.media_storage import SpamMediaException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
@ -86,9 +87,14 @@ class UploadResource(DirectServeJsonResource):
|
|||
# disposition = headers.getRawHeaders(b"Content-Disposition")[0]
|
||||
# TODO(markjh): parse content-dispostion
|
||||
|
||||
content_uri = await self.media_repo.create_content(
|
||||
media_type, upload_name, request.content, content_length, requester.user
|
||||
)
|
||||
try:
|
||||
content_uri = await self.media_repo.create_content(
|
||||
media_type, upload_name, request.content, content_length, requester.user
|
||||
)
|
||||
except SpamMediaException:
|
||||
# For uploading of media we want to respond with a 400, instead of
|
||||
# the default 404, as that would just be confusing.
|
||||
raise SynapseError(400, "Bad content")
|
||||
|
||||
logger.info("Uploaded content with URI %r", content_uri)
|
||||
|
||||
|
|
|
@ -191,53 +191,6 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEqual(sg, sg2)
|
||||
|
||||
@unittest.override_config(
|
||||
{"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
|
||||
)
|
||||
def test_invite_by_room_ratelimit(self):
|
||||
"""Tests that invites from federation in a room are actually rate-limited.
|
||||
"""
|
||||
other_server = "otherserver"
|
||||
other_user = "@otheruser:" + other_server
|
||||
|
||||
# create the room
|
||||
user_id = self.register_user("kermit", "test")
|
||||
tok = self.login("kermit", "test")
|
||||
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
|
||||
room_version = self.get_success(self.store.get_room_version(room_id))
|
||||
|
||||
def create_invite_for(local_user):
|
||||
return event_from_pdu_json(
|
||||
{
|
||||
"type": EventTypes.Member,
|
||||
"content": {"membership": "invite"},
|
||||
"room_id": room_id,
|
||||
"sender": other_user,
|
||||
"state_key": local_user,
|
||||
"depth": 32,
|
||||
"prev_events": [],
|
||||
"auth_events": [],
|
||||
"origin_server_ts": self.clock.time_msec(),
|
||||
},
|
||||
room_version,
|
||||
)
|
||||
|
||||
for i in range(3):
|
||||
self.get_success(
|
||||
self.handler.on_invite_request(
|
||||
other_server,
|
||||
create_invite_for("@user-%d:test" % (i,)),
|
||||
room_version,
|
||||
)
|
||||
)
|
||||
|
||||
self.get_failure(
|
||||
self.handler.on_invite_request(
|
||||
other_server, create_invite_for("@user-4:test"), room_version,
|
||||
),
|
||||
exc=LimitExceededError,
|
||||
)
|
||||
|
||||
@unittest.override_config(
|
||||
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
|
||||
)
|
||||
|
|
|
@ -30,6 +30,8 @@ from twisted.internet import defer
|
|||
from twisted.internet.defer import Deferred
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client.v1 import login
|
||||
from synapse.rest.media.v1._base import FileInfo
|
||||
from synapse.rest.media.v1.filepath import MediaFilePaths
|
||||
from synapse.rest.media.v1.media_storage import MediaStorage
|
||||
|
@ -37,6 +39,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
|
|||
|
||||
from tests import unittest
|
||||
from tests.server import FakeSite, make_request
|
||||
from tests.utils import default_config
|
||||
|
||||
|
||||
class MediaStorageTests(unittest.HomeserverTestCase):
|
||||
|
@ -398,3 +401,94 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
|||
headers.getRawHeaders(b"X-Robots-Tag"),
|
||||
[b"noindex, nofollow, noarchive, noimageindex"],
|
||||
)
|
||||
|
||||
|
||||
class TestSpamChecker:
|
||||
"""A spam checker module that rejects all media that includes the bytes
|
||||
`evil`.
|
||||
"""
|
||||
|
||||
def __init__(self, config, api):
|
||||
self.config = config
|
||||
self.api = api
|
||||
|
||||
def parse_config(config):
|
||||
return config
|
||||
|
||||
async def check_event_for_spam(self, foo):
|
||||
return False # allow all events
|
||||
|
||||
async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
|
||||
return True # allow all invites
|
||||
|
||||
async def user_may_create_room(self, userid):
|
||||
return True # allow all room creations
|
||||
|
||||
async def user_may_create_room_alias(self, userid, room_alias):
|
||||
return True # allow all room aliases
|
||||
|
||||
async def user_may_publish_room(self, userid, room_id):
|
||||
return True # allow publishing of all rooms
|
||||
|
||||
async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
|
||||
buf = BytesIO()
|
||||
await file_wrapper.write_chunks_to(buf.write)
|
||||
|
||||
return b"evil" in buf.getvalue()
|
||||
|
||||
|
||||
class SpamCheckerTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
login.register_servlets,
|
||||
admin.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.user = self.register_user("user", "pass")
|
||||
self.tok = self.login("user", "pass")
|
||||
|
||||
# Allow for uploading and downloading to/from the media repo
|
||||
self.media_repo = hs.get_media_repository_resource()
|
||||
self.download_resource = self.media_repo.children[b"download"]
|
||||
self.upload_resource = self.media_repo.children[b"upload"]
|
||||
|
||||
def default_config(self):
|
||||
config = default_config("test")
|
||||
|
||||
config.update(
|
||||
{
|
||||
"spam_checker": [
|
||||
{
|
||||
"module": TestSpamChecker.__module__ + ".TestSpamChecker",
|
||||
"config": {},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
def test_upload_innocent(self):
|
||||
"""Attempt to upload some innocent data that should be allowed.
|
||||
"""
|
||||
|
||||
image_data = unhexlify(
|
||||
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
|
||||
b"0000001f15c4890000000a49444154789c63000100000500010d"
|
||||
b"0a2db40000000049454e44ae426082"
|
||||
)
|
||||
|
||||
self.helper.upload_media(
|
||||
self.upload_resource, image_data, tok=self.tok, expect_code=200
|
||||
)
|
||||
|
||||
def test_upload_ban(self):
|
||||
"""Attempt to upload some data that includes bytes "evil", which should
|
||||
get rejected by the spam checker.
|
||||
"""
|
||||
|
||||
data = b"Some evil data"
|
||||
|
||||
self.helper.upload_media(
|
||||
self.upload_resource, data, tok=self.tok, expect_code=400
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue