0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 08:13:48 +01:00

Prevent the media store from writing outside of the configured directory

Also tighten validation of server names by forbidding invalid characters
in IPv6 addresses and empty domain labels.
This commit is contained in:
Sean Quah 2021-11-19 13:39:15 +00:00
parent 9f9d82aa84
commit 91f2bd0907
5 changed files with 483 additions and 50 deletions

View file

@ -29,7 +29,7 @@ from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json from synapse.http.server import finish_request, respond_with_json
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.util.stringutils import is_ascii from synapse.util.stringutils import is_ascii, parse_and_validate_server_name
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -51,6 +51,19 @@ TEXT_CONTENT_TYPES = [
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
"""Parses the server name, media ID and optional file name from the request URI
Also performs some rough validation on the server name.
Args:
request: The `Request`.
Returns:
A tuple containing the parsed server name, media ID and optional file name.
Raises:
SynapseError(404): if parsing or validation fail for any reason
"""
try: try:
# The type on postpath seems incorrect in Twisted 21.2.0. # The type on postpath seems incorrect in Twisted 21.2.0.
postpath: List[bytes] = request.postpath # type: ignore postpath: List[bytes] = request.postpath # type: ignore
@ -62,6 +75,9 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
server_name = server_name_bytes.decode("utf-8") server_name = server_name_bytes.decode("utf-8")
media_id = media_id_bytes.decode("utf8") media_id = media_id_bytes.decode("utf8")
# Validate the server name, raising if invalid
parse_and_validate_server_name(server_name)
file_name = None file_name = None
if len(postpath) > 2: if len(postpath) > 2:
try: try:

View file

@ -16,7 +16,8 @@
import functools import functools
import os import os
import re import re
from typing import Any, Callable, List, TypeVar, cast import string
from typing import Any, Callable, List, TypeVar, Union, cast
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
@ -37,6 +38,85 @@ def _wrap_in_base_path(func: F) -> F:
return cast(F, _wrapped) return cast(F, _wrapped)
GetPathMethod = TypeVar(
"GetPathMethod", bound=Union[Callable[..., str], Callable[..., List[str]]]
)
def _wrap_with_jail_check(func: GetPathMethod) -> GetPathMethod:
"""Wraps a path-returning method to check that the returned path(s) do not escape
the media store directory.
The check is not expected to ever fail, unless `func` is missing a call to
`_validate_path_component`, or `_validate_path_component` is buggy.
Args:
func: The `MediaFilePaths` method to wrap. The method may return either a single
path, or a list of paths. Returned paths may be either absolute or relative.
Returns:
The method, wrapped with a check to ensure that the returned path(s) lie within
the media store directory. Raises a `ValueError` if the check fails.
"""
@functools.wraps(func)
def _wrapped(
self: "MediaFilePaths", *args: Any, **kwargs: Any
) -> Union[str, List[str]]:
path_or_paths = func(self, *args, **kwargs)
if isinstance(path_or_paths, list):
paths_to_check = path_or_paths
else:
paths_to_check = [path_or_paths]
for path in paths_to_check:
# path may be an absolute or relative path, depending on the method being
# wrapped. When "appending" an absolute path, `os.path.join` discards the
# previous path, which is desired here.
normalized_path = os.path.normpath(os.path.join(self.real_base_path, path))
if (
os.path.commonpath([normalized_path, self.real_base_path])
!= self.real_base_path
):
raise ValueError(f"Invalid media store path: {path!r}")
return path_or_paths
return cast(GetPathMethod, _wrapped)
ALLOWED_CHARACTERS = set(
string.ascii_letters
+ string.digits
+ "_-"
+ ".[]:" # Domain names, IPv6 addresses and ports in server names
)
FORBIDDEN_NAMES = {
"",
os.path.curdir, # "." for the current platform
os.path.pardir, # ".." for the current platform
}
def _validate_path_component(name: str) -> str:
"""Checks that the given string can be safely used as a path component
Args:
name: The path component to check.
Returns:
The path component if valid.
Raises:
ValueError: If `name` cannot be safely used as a path component.
"""
if not ALLOWED_CHARACTERS.issuperset(name) or name in FORBIDDEN_NAMES:
raise ValueError(f"Invalid path component: {name!r}")
return name
class MediaFilePaths: class MediaFilePaths:
"""Describes where files are stored on disk. """Describes where files are stored on disk.
@ -48,22 +128,46 @@ class MediaFilePaths:
def __init__(self, primary_base_path: str): def __init__(self, primary_base_path: str):
self.base_path = primary_base_path self.base_path = primary_base_path
# The media store directory, with all symlinks resolved.
self.real_base_path = os.path.realpath(primary_base_path)
# Refuse to initialize if paths cannot be validated correctly for the current
# platform.
assert os.path.sep not in ALLOWED_CHARACTERS
assert os.path.altsep not in ALLOWED_CHARACTERS
# On Windows, paths have all sorts of weirdness which `_validate_path_component`
# does not consider. In any case, the remote media store can't work correctly
# for certain homeservers there, since ":"s aren't allowed in paths.
assert os.name == "posix"
@_wrap_with_jail_check
def local_media_filepath_rel(self, media_id: str) -> str: def local_media_filepath_rel(self, media_id: str) -> str:
return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:]) return os.path.join(
"local_content",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
)
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
@_wrap_with_jail_check
def local_media_thumbnail_rel( def local_media_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str: ) -> str:
top_level_type, sub_type = content_type.split("/") top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join( return os.path.join(
"local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name "local_thumbnails",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
_validate_path_component(file_name),
) )
local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
@_wrap_with_jail_check
def local_media_thumbnail_dir(self, media_id: str) -> str: def local_media_thumbnail_dir(self, media_id: str) -> str:
""" """
Retrieve the local store path of thumbnails of a given media_id Retrieve the local store path of thumbnails of a given media_id
@ -76,18 +180,24 @@ class MediaFilePaths:
return os.path.join( return os.path.join(
self.base_path, self.base_path,
"local_thumbnails", "local_thumbnails",
media_id[0:2], _validate_path_component(media_id[0:2]),
media_id[2:4], _validate_path_component(media_id[2:4]),
media_id[4:], _validate_path_component(media_id[4:]),
) )
@_wrap_with_jail_check
def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
return os.path.join( return os.path.join(
"remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] "remote_content",
_validate_path_component(server_name),
_validate_path_component(file_id[0:2]),
_validate_path_component(file_id[2:4]),
_validate_path_component(file_id[4:]),
) )
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
@_wrap_with_jail_check
def remote_media_thumbnail_rel( def remote_media_thumbnail_rel(
self, self,
server_name: str, server_name: str,
@ -101,11 +211,11 @@ class MediaFilePaths:
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join( return os.path.join(
"remote_thumbnail", "remote_thumbnail",
server_name, _validate_path_component(server_name),
file_id[0:2], _validate_path_component(file_id[0:2]),
file_id[2:4], _validate_path_component(file_id[2:4]),
file_id[4:], _validate_path_component(file_id[4:]),
file_name, _validate_path_component(file_name),
) )
remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
@ -113,6 +223,7 @@ class MediaFilePaths:
# Legacy path that was used to store thumbnails previously. # Legacy path that was used to store thumbnails previously.
# Should be removed after some time, when most of the thumbnails are stored # Should be removed after some time, when most of the thumbnails are stored
# using the new path. # using the new path.
@_wrap_with_jail_check
def remote_media_thumbnail_rel_legacy( def remote_media_thumbnail_rel_legacy(
self, server_name: str, file_id: str, width: int, height: int, content_type: str self, server_name: str, file_id: str, width: int, height: int, content_type: str
) -> str: ) -> str:
@ -120,43 +231,66 @@ class MediaFilePaths:
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join( return os.path.join(
"remote_thumbnail", "remote_thumbnail",
server_name, _validate_path_component(server_name),
file_id[0:2], _validate_path_component(file_id[0:2]),
file_id[2:4], _validate_path_component(file_id[2:4]),
file_id[4:], _validate_path_component(file_id[4:]),
file_name, _validate_path_component(file_name),
) )
def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
return os.path.join( return os.path.join(
self.base_path, self.base_path,
"remote_thumbnail", "remote_thumbnail",
server_name, _validate_path_component(server_name),
file_id[0:2], _validate_path_component(file_id[0:2]),
file_id[2:4], _validate_path_component(file_id[2:4]),
file_id[4:], _validate_path_component(file_id[4:]),
) )
@_wrap_with_jail_check
def url_cache_filepath_rel(self, media_id: str) -> str: def url_cache_filepath_rel(self, media_id: str) -> str:
if NEW_FORMAT_ID_RE.match(media_id): if NEW_FORMAT_ID_RE.match(media_id):
# Media id is of the form <DATE><RANDOM_STRING> # Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf # E.g.: 2017-09-28-fsdRDt24DS234dsf
return os.path.join("url_cache", media_id[:10], media_id[11:]) return os.path.join(
"url_cache",
_validate_path_component(media_id[:10]),
_validate_path_component(media_id[11:]),
)
else: else:
return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:]) return os.path.join(
"url_cache",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
)
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
@_wrap_with_jail_check
def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id file" "The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id): if NEW_FORMAT_ID_RE.match(media_id):
return [os.path.join(self.base_path, "url_cache", media_id[:10])] return [
os.path.join(
self.base_path, "url_cache", _validate_path_component(media_id[:10])
)
]
else: else:
return [ return [
os.path.join(self.base_path, "url_cache", media_id[0:2], media_id[2:4]), os.path.join(
os.path.join(self.base_path, "url_cache", media_id[0:2]), self.base_path,
"url_cache",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
),
os.path.join(
self.base_path, "url_cache", _validate_path_component(media_id[0:2])
),
] ]
@_wrap_with_jail_check
def url_cache_thumbnail_rel( def url_cache_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str: ) -> str:
@ -168,37 +302,46 @@ class MediaFilePaths:
if NEW_FORMAT_ID_RE.match(media_id): if NEW_FORMAT_ID_RE.match(media_id):
return os.path.join( return os.path.join(
"url_cache_thumbnails", media_id[:10], media_id[11:], file_name "url_cache_thumbnails",
_validate_path_component(media_id[:10]),
_validate_path_component(media_id[11:]),
_validate_path_component(file_name),
) )
else: else:
return os.path.join( return os.path.join(
"url_cache_thumbnails", "url_cache_thumbnails",
media_id[0:2], _validate_path_component(media_id[0:2]),
media_id[2:4], _validate_path_component(media_id[2:4]),
media_id[4:], _validate_path_component(media_id[4:]),
file_name, _validate_path_component(file_name),
) )
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
@_wrap_with_jail_check
def url_cache_thumbnail_directory_rel(self, media_id: str) -> str: def url_cache_thumbnail_directory_rel(self, media_id: str) -> str:
# Media id is of the form <DATE><RANDOM_STRING> # Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf # E.g.: 2017-09-28-fsdRDt24DS234dsf
if NEW_FORMAT_ID_RE.match(media_id): if NEW_FORMAT_ID_RE.match(media_id):
return os.path.join("url_cache_thumbnails", media_id[:10], media_id[11:]) return os.path.join(
"url_cache_thumbnails",
_validate_path_component(media_id[:10]),
_validate_path_component(media_id[11:]),
)
else: else:
return os.path.join( return os.path.join(
"url_cache_thumbnails", "url_cache_thumbnails",
media_id[0:2], _validate_path_component(media_id[0:2]),
media_id[2:4], _validate_path_component(media_id[2:4]),
media_id[4:], _validate_path_component(media_id[4:]),
) )
url_cache_thumbnail_directory = _wrap_in_base_path( url_cache_thumbnail_directory = _wrap_in_base_path(
url_cache_thumbnail_directory_rel url_cache_thumbnail_directory_rel
) )
@_wrap_with_jail_check
def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id thumbnails" "The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING> # Media id is of the form <DATE><RANDOM_STRING>
@ -206,21 +349,35 @@ class MediaFilePaths:
if NEW_FORMAT_ID_RE.match(media_id): if NEW_FORMAT_ID_RE.match(media_id):
return [ return [
os.path.join( os.path.join(
self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:] self.base_path,
"url_cache_thumbnails",
_validate_path_component(media_id[:10]),
_validate_path_component(media_id[11:]),
),
os.path.join(
self.base_path,
"url_cache_thumbnails",
_validate_path_component(media_id[:10]),
), ),
os.path.join(self.base_path, "url_cache_thumbnails", media_id[:10]),
] ]
else: else:
return [ return [
os.path.join( os.path.join(
self.base_path, self.base_path,
"url_cache_thumbnails", "url_cache_thumbnails",
media_id[0:2], _validate_path_component(media_id[0:2]),
media_id[2:4], _validate_path_component(media_id[2:4]),
media_id[4:], _validate_path_component(media_id[4:]),
), ),
os.path.join( os.path.join(
self.base_path, "url_cache_thumbnails", media_id[0:2], media_id[2:4] self.base_path,
"url_cache_thumbnails",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
),
os.path.join(
self.base_path,
"url_cache_thumbnails",
_validate_path_component(media_id[0:2]),
), ),
os.path.join(self.base_path, "url_cache_thumbnails", media_id[0:2]),
] ]

View file

@ -19,6 +19,8 @@ import string
from collections.abc import Iterable from collections.abc import Iterable
from typing import Optional, Tuple from typing import Optional, Tuple
from netaddr import valid_ipv6
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
@ -97,7 +99,10 @@ def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
raise ValueError("Invalid server name '%s'" % server_name) raise ValueError("Invalid server name '%s'" % server_name)
VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") # An approximation of the domain name syntax in RFC 1035, section 2.3.1.
# NB: "\Z" is not equivalent to "$".
# The latter will match the position before a "\n" at the end of a string.
VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*\\Z")
def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]: def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
@ -122,13 +127,15 @@ def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]
if host[0] == "[": if host[0] == "[":
if host[-1] != "]": if host[-1] != "]":
raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
return host, port
# otherwise it should only be alphanumerics. # valid_ipv6 raises when given an empty string
if not VALID_HOST_REGEX.match(host): ipv6_address = host[1:-1]
if not ipv6_address or not valid_ipv6(ipv6_address):
raise ValueError( raise ValueError(
"Server name '%s' contains invalid characters" % (server_name,) "Server name '%s' is not a valid IPv6 address" % (server_name,)
) )
elif not VALID_HOST_REGEX.match(host):
raise ValueError("Server name '%s' has an invalid format" % (server_name,))
return host, port return host, port

View file

@ -36,8 +36,11 @@ class ServerNameTestCase(unittest.TestCase):
"localhost:http", # non-numeric port "localhost:http", # non-numeric port
"1234]", # smells like ipv6 literal but isn't "1234]", # smells like ipv6 literal but isn't
"[1234", "[1234",
"[1.2.3.4]",
"underscore_.com", "underscore_.com",
"percent%65.com", "percent%65.com",
"newline.com\n",
".empty-label.com",
"1234:5678:80", # too many colons "1234:5678:80", # too many colons
] ]
for i in test_data: for i in test_data:

View file

@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
from typing import Iterable
from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.filepath import MediaFilePaths
from tests import unittest from tests import unittest
@ -236,3 +239,250 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/Ge", "/media_store/url_cache_thumbnails/Ge",
], ],
) )
def test_server_name_validation(self):
"""Test validation of server names"""
self._test_path_validation(
[
"remote_media_filepath_rel",
"remote_media_filepath",
"remote_media_thumbnail_rel",
"remote_media_thumbnail",
"remote_media_thumbnail_rel_legacy",
"remote_media_thumbnail_dir",
],
parameter="server_name",
valid_values=[
"matrix.org",
"matrix.org:8448",
"matrix-federation.matrix.org",
"matrix-federation.matrix.org:8448",
"10.1.12.123",
"10.1.12.123:8448",
"[fd00:abcd::ffff]",
"[fd00:abcd::ffff]:8448",
],
invalid_values=[
"/matrix.org",
"matrix.org/..",
"matrix.org\x00",
"",
".",
"..",
"/",
],
)
def test_file_id_validation(self):
"""Test validation of local, remote and legacy URL cache file / media IDs"""
# File / media IDs get split into three parts to form paths, consisting of the
# first two characters, next two characters and rest of the ID.
valid_file_ids = [
"GerZNDnDZVjsOtardLuwfIBg",
# Unexpected, but produces an acceptable path:
"GerZN", # "N" becomes the last directory
]
invalid_file_ids = [
"/erZNDnDZVjsOtardLuwfIBg",
"Ge/ZNDnDZVjsOtardLuwfIBg",
"GerZ/DnDZVjsOtardLuwfIBg",
"GerZ/..",
"G\x00rZNDnDZVjsOtardLuwfIBg",
"Ger\x00NDnDZVjsOtardLuwfIBg",
"GerZNDnDZVjsOtardLuwfIBg\x00",
"",
"Ge",
"GerZ",
"GerZ.",
"..rZNDnDZVjsOtardLuwfIBg",
"Ge..NDnDZVjsOtardLuwfIBg",
"GerZ..",
"GerZ/",
]
self._test_path_validation(
[
"local_media_filepath_rel",
"local_media_filepath",
"local_media_thumbnail_rel",
"local_media_thumbnail",
"local_media_thumbnail_dir",
# Legacy URL cache media IDs
"url_cache_filepath_rel",
"url_cache_filepath",
# `url_cache_filepath_dirs_to_delete` is tested below.
"url_cache_thumbnail_rel",
"url_cache_thumbnail",
"url_cache_thumbnail_directory_rel",
"url_cache_thumbnail_directory",
"url_cache_thumbnail_dirs_to_delete",
],
parameter="media_id",
valid_values=valid_file_ids,
invalid_values=invalid_file_ids,
)
# `url_cache_filepath_dirs_to_delete` ignores what would be the last path
# component, so only the first 4 characters matter.
self._test_path_validation(
[
"url_cache_filepath_dirs_to_delete",
],
parameter="media_id",
valid_values=valid_file_ids,
invalid_values=[
"/erZNDnDZVjsOtardLuwfIBg",
"Ge/ZNDnDZVjsOtardLuwfIBg",
"G\x00rZNDnDZVjsOtardLuwfIBg",
"Ger\x00NDnDZVjsOtardLuwfIBg",
"",
"Ge",
"..rZNDnDZVjsOtardLuwfIBg",
"Ge..NDnDZVjsOtardLuwfIBg",
],
)
self._test_path_validation(
[
"remote_media_filepath_rel",
"remote_media_filepath",
"remote_media_thumbnail_rel",
"remote_media_thumbnail",
"remote_media_thumbnail_rel_legacy",
"remote_media_thumbnail_dir",
],
parameter="file_id",
valid_values=valid_file_ids,
invalid_values=invalid_file_ids,
)
def test_url_cache_media_id_validation(self):
"""Test validation of URL cache media IDs"""
self._test_path_validation(
[
"url_cache_filepath_rel",
"url_cache_filepath",
# `url_cache_filepath_dirs_to_delete` only cares about the date prefix
"url_cache_thumbnail_rel",
"url_cache_thumbnail",
"url_cache_thumbnail_directory_rel",
"url_cache_thumbnail_directory",
"url_cache_thumbnail_dirs_to_delete",
],
parameter="media_id",
valid_values=[
"2020-01-02_GerZNDnDZVjsOtar",
"2020-01-02_G", # Unexpected, but produces an acceptable path
],
invalid_values=[
"2020-01-02",
"2020-01-02-",
"2020-01-02-.",
"2020-01-02-..",
"2020-01-02-/",
"2020-01-02-/GerZNDnDZVjsOtar",
"2020-01-02-GerZNDnDZVjsOtar/..",
"2020-01-02-GerZNDnDZVjsOtar\x00",
],
)
def test_content_type_validation(self):
"""Test validation of thumbnail content types"""
self._test_path_validation(
[
"local_media_thumbnail_rel",
"local_media_thumbnail",
"remote_media_thumbnail_rel",
"remote_media_thumbnail",
"remote_media_thumbnail_rel_legacy",
"url_cache_thumbnail_rel",
"url_cache_thumbnail",
],
parameter="content_type",
valid_values=[
"image/jpeg",
],
invalid_values=[
"", # ValueError: not enough values to unpack
"image/jpeg/abc", # ValueError: too many values to unpack
"image/jpeg\x00",
],
)
def test_thumbnail_method_validation(self):
"""Test validation of thumbnail methods"""
self._test_path_validation(
[
"local_media_thumbnail_rel",
"local_media_thumbnail",
"remote_media_thumbnail_rel",
"remote_media_thumbnail",
"url_cache_thumbnail_rel",
"url_cache_thumbnail",
],
parameter="method",
valid_values=[
"crop",
"scale",
],
invalid_values=[
"/scale",
"scale/..",
"scale\x00",
"/",
],
)
def _test_path_validation(
self,
methods: Iterable[str],
parameter: str,
valid_values: Iterable[str],
invalid_values: Iterable[str],
):
"""Test that the specified methods validate the named parameter as expected
Args:
methods: The names of `MediaFilePaths` methods to test
parameter: The name of the parameter to test
valid_values: A list of parameter values that are expected to be accepted
invalid_values: A list of parameter values that are expected to be rejected
Raises:
AssertionError: If a value was accepted when it should have failed
validation.
ValueError: If a value failed validation when it should have been accepted.
"""
for method in methods:
get_path = getattr(self.filepaths, method)
parameters = inspect.signature(get_path).parameters
kwargs = {
"server_name": "matrix.org",
"media_id": "GerZNDnDZVjsOtardLuwfIBg",
"file_id": "GerZNDnDZVjsOtardLuwfIBg",
"width": 800,
"height": 600,
"content_type": "image/jpeg",
"method": "scale",
}
if get_path.__name__.startswith("url_"):
kwargs["media_id"] = "2020-01-02_GerZNDnDZVjsOtar"
kwargs = {k: v for k, v in kwargs.items() if k in parameters}
kwargs.pop(parameter)
for value in valid_values:
kwargs[parameter] = value
get_path(**kwargs)
# No exception should be raised
for value in invalid_values:
with self.assertRaises(ValueError):
kwargs[parameter] = value
path_or_list = get_path(**kwargs)
self.fail(
f"{value!r} unexpectedly passed validation: "
f"{method} returned {path_or_list!r}"
)