synapse/tests/rest/media/v1/test_media_storage.py
Patrick Cloke b312769c0e
Do not error when thumbnailing invalid files (#8236)
If a file cannot be thumbnailed for some reason (e.g. the file is empty), then
catch the exception and convert it to a reasonable error message for the client.
2020-09-09 12:59:41 -04:00

354 lines
12 KiB
Python

# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import tempfile
from binascii import unhexlify
from io import BytesIO
from typing import Optional
from urllib import parse
from mock import Mock
import attr
from parameterized import parameterized_class
from PIL import Image as Image
from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.logging.context import make_deferred_yieldable
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from tests import unittest
class MediaStorageTests(unittest.HomeserverTestCase):
    """Tests for `MediaStorage` with a local primary store plus a
    `FileStorageProviderBackend` acting as a secondary store.
    """

    needs_threadpool = True

    def prepare(self, reactor, clock, hs):
        # Scratch directory holding both the primary media store and the
        # secondary (storage provider) directory; removed after the test.
        self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
        self.addCleanup(shutil.rmtree, self.test_dir)

        self.primary_base_path = os.path.join(self.test_dir, "primary")
        self.secondary_base_path = os.path.join(self.test_dir, "secondary")

        hs.config.media_store_path = self.primary_base_path

        storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)]

        self.filepaths = MediaFilePaths(self.primary_base_path)
        self.media_storage = MediaStorage(
            hs, self.primary_base_path, self.filepaths, storage_providers
        )

    def test_ensure_media_is_in_local_cache(self):
        """A file present only in the storage provider is copied into the
        local primary media store by `ensure_media_is_in_local_cache`.
        """
        media_id = "some_media_id"
        test_body = "Test\n"

        # First we create a file that is in a storage provider but not in the
        # local primary media store
        rel_path = self.filepaths.local_media_filepath_rel(media_id)
        secondary_path = os.path.join(self.secondary_base_path, rel_path)

        os.makedirs(os.path.dirname(secondary_path))

        with open(secondary_path, "w") as f:
            f.write(test_body)

        # Now we run ensure_media_is_in_local_cache, which should copy the file
        # to the local cache.
        file_info = FileInfo(None, media_id)

        # This uses a real blocking threadpool so we have to wait for it to be
        # actually done :/
        x = defer.ensureDeferred(
            self.media_storage.ensure_media_is_in_local_cache(file_info)
        )

        # Hotloop until the threadpool does its job...
        self.wait_on_thread(x)

        local_path = self.get_success(x)

        self.assertTrue(os.path.exists(local_path))

        # Asserts the file is under the expected local cache directory.
        # `commonpath` compares whole path components; `commonprefix` works
        # character-by-character and would also accept ".../primary2/...".
        self.assertEqual(
            os.path.commonpath([self.primary_base_path, local_path]),
            self.primary_base_path,
        )

        with open(local_path) as f:
            body = f.read()

        self.assertEqual(test_body, body)
@attr.s
class _TestImage:
    """An image for testing thumbnailing with the expected results

    Attributes:
        data: The raw image to thumbnail
        content_type: The type of the image as a content type (bytes),
            e.g. b"image/png"
        extension: The extension associated with the format (bytes),
            e.g. b".png"
        expected_cropped: The expected bytes from cropped thumbnailing, or None
            if the test should just check that a valid image is returned.
        expected_scaled: The expected bytes from scaled thumbnailing, or None
            if the test should just check that a valid image is returned.
        expected_found: True if thumbnailing is expected to succeed with a 200;
            False if the thumbnail request is expected to fail with a 404
            (e.g. because the file cannot be thumbnailed).
    """

    data = attr.ib(type=bytes)
    content_type = attr.ib(type=bytes)
    extension = attr.ib(type=bytes)
    expected_cropped = attr.ib(type=Optional[bytes])
    expected_scaled = attr.ib(type=Optional[bytes])
    expected_found = attr.ib(default=True, type=bool)
@parameterized_class(
    ("test_image",),
    [
        # smoll png
        (
            _TestImage(
                unhexlify(
                    b"89504e470d0a1a0a0000000d4948445200000001000000010806"
                    b"0000001f15c4890000000a49444154789c63000100000500010d"
                    b"0a2db40000000049454e44ae426082"
                ),
                b"image/png",
                b".png",
                unhexlify(
                    b"89504e470d0a1a0a0000000d4948445200000020000000200806"
                    b"000000737a7af40000001a49444154789cedc101010000008220"
                    b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
                    b"44ae426082"
                ),
                unhexlify(
                    b"89504e470d0a1a0a0000000d4948445200000001000000010806"
                    b"0000001f15c4890000000d49444154789c636060606000000005"
                    b"0001a5f645400000000049454e44ae426082"
                ),
            ),
        ),
        # small lossless webp
        (
            _TestImage(
                unhexlify(
                    b"524946461a000000574542505650384c0d0000002f0000001007"
                    b"1011118888fe0700"
                ),
                b"image/webp",
                b".webp",
                None,
                None,
            ),
        ),
        # an empty file: it cannot be thumbnailed, so a 404 is expected
        (
            _TestImage(
                b"", b"image/gif", b".gif", None, None, expected_found=False
            ),
        ),
    ],
)
class MediaRepoTests(unittest.HomeserverTestCase):
    """Tests of the media repository's download and thumbnail resources for
    remote media, run once for each `_TestImage` given as `test_image`.
    """

    hijack_auth = True
    user_id = "@test:user"

    def make_homeserver(self, reactor, clock):

        # Pending remote fetches, as (deferred, destination, path, args).
        # A test completes a fetch by firing its deferred with
        # (data, response).
        self.fetches = []

        def get_file(destination, path, output_stream, args=None, max_size=None):
            """Stub for the HTTP client's `get_file`.

            Records the request in `self.fetches` and returns a deferred. When a
            test fires it with `(data, response)`, `data` is written to
            `output_stream` and the deferred resolves to `response` (the tests
            pass `(length, headers)`).
            """

            def write_to(r):
                data, response = r
                output_stream.write(data)
                return response

            d = Deferred()
            d.addCallback(write_to)
            self.fetches.append((d, destination, path, args))
            return make_deferred_yieldable(d)

        client = Mock()
        client.get_file = get_file

        self.storage_path = self.mktemp()
        self.media_store_path = self.mktemp()
        os.mkdir(self.storage_path)
        os.mkdir(self.media_store_path)

        config = self.default_config()
        config["media_store_path"] = self.media_store_path
        config["thumbnail_requirements"] = {}
        config["max_image_pixels"] = 2000000

        provider_config = {
            "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
            "store_local": True,
            "store_synchronous": False,
            "store_remote": True,
            "config": {"directory": self.storage_path},
        }
        config["media_storage_providers"] = [provider_config]

        return self.setup_test_homeserver(config=config, http_client=client)

    def prepare(self, reactor, clock, hs):

        self.media_repo = hs.get_media_repository_resource()
        self.download_resource = self.media_repo.children[b"download"]
        self.thumbnail_resource = self.media_repo.children[b"thumbnail"]

        self.media_id = "example.com/12345"

    def _respond_to_fetch(self, content_disposition=None):
        """Complete the first recorded remote fetch with `self.test_image`,
        then pump the reactor.

        Args:
            content_disposition: If truthy, sent as the Content-Disposition
                header of the simulated remote response.
        """
        headers = {
            b"Content-Length": [b"%d" % len(self.test_image.data)],
            b"Content-Type": [self.test_image.content_type],
        }
        if content_disposition:
            headers[b"Content-Disposition"] = [content_disposition]

        self.fetches[0][0].callback(
            (self.test_image.data, (len(self.test_image.data), headers))
        )
        self.pump()

    def _req(self, content_disposition):
        """Download `self.media_id`, answering the remote fetch with
        `self.test_image`, and assert that the download succeeds.

        Args:
            content_disposition: The Content-Disposition header value (bytes)
                the remote server responds with, or None to send none.

        Returns:
            The channel returned by `make_request`, holding the response.
        """
        request, channel = self.make_request("GET", self.media_id, shorthand=False)
        request.render(self.download_resource)
        self.pump()

        # We've made one fetch, to example.com, using the media URL, and asking
        # the other server not to do a remote fetch
        self.assertEqual(len(self.fetches), 1)
        self.assertEqual(self.fetches[0][1], "example.com")
        self.assertEqual(
            self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id
        )
        self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})

        self._respond_to_fetch(content_disposition)

        self.assertEqual(channel.code, 200)

        return channel

    def test_disposition_filename_ascii(self):
        """
        If the filename is filename=<ascii> then Synapse will decode it as an
        ASCII string, and use filename= in the response.
        """
        channel = self._req(b"inline; filename=out" + self.test_image.extension)

        headers = channel.headers
        self.assertEqual(
            headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
        )
        self.assertEqual(
            headers.getRawHeaders(b"Content-Disposition"),
            [b"inline; filename=out" + self.test_image.extension],
        )

    def test_disposition_filenamestar_utf8escaped(self):
        """
        If the filename is filename=*utf8''<utf8 escaped> then Synapse will
        correctly decode it as the UTF-8 string, and use filename* in the
        response.
        """
        filename = parse.quote("\u2603".encode("utf8")).encode("ascii")
        channel = self._req(
            b"inline; filename*=utf-8''" + filename + self.test_image.extension
        )

        headers = channel.headers
        self.assertEqual(
            headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
        )
        self.assertEqual(
            headers.getRawHeaders(b"Content-Disposition"),
            [b"inline; filename*=utf-8''" + filename + self.test_image.extension],
        )

    def test_disposition_none(self):
        """
        If there is no filename, one isn't passed on in the Content-Disposition
        of the request.
        """
        channel = self._req(None)

        headers = channel.headers
        self.assertEqual(
            headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
        )
        self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)

    def test_thumbnail_crop(self):
        """Cropped thumbnailing of the test image gives the expected result."""
        self._test_thumbnail(
            "crop", self.test_image.expected_cropped, self.test_image.expected_found
        )

    def test_thumbnail_scale(self):
        """Scaled thumbnailing of the test image gives the expected result."""
        self._test_thumbnail(
            "scale", self.test_image.expected_scaled, self.test_image.expected_found
        )

    def _test_thumbnail(self, method, expected_body, expected_found):
        """Request a 32x32 thumbnail of `self.media_id` and check the response.

        Args:
            method: The thumbnailing method, "crop" or "scale".
            expected_body: The exact expected thumbnail bytes, or None to only
                check that the response body can be opened as an image.
            expected_found: True to expect a 200; False to expect a 404 with an
                M_NOT_FOUND JSON error body.
        """
        params = "?width=32&height=32&method=" + method
        request, channel = self.make_request(
            "GET", self.media_id + params, shorthand=False
        )
        request.render(self.thumbnail_resource)
        self.pump()

        # Fail with a clear assertion (rather than an IndexError below) if the
        # remote fetch of the original media was not made.
        self.assertEqual(len(self.fetches), 1)
        self._respond_to_fetch()

        if expected_found:
            self.assertEqual(channel.code, 200)
            if expected_body is not None:
                self.assertEqual(
                    channel.result["body"], expected_body, channel.result["body"]
                )
            else:
                # ensure that the result is at least some valid image
                Image.open(BytesIO(channel.result["body"]))
        else:
            # A 404 with a JSON body.
            self.assertEqual(channel.code, 404)
            self.assertEqual(
                channel.json_body,
                {
                    "errcode": "M_NOT_FOUND",
                    "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']"
                    % method,
                },
            )