mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-05 00:03:53 +01:00
Add type hints to tests/rest
. (#12208)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
parent
e10a2fe0c2
commit
32c828d0f7
5 changed files with 129 additions and 85 deletions
1
changelog.d/12208.misc
Normal file
1
changelog.d/12208.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add type hints to tests files.
|
1
mypy.ini
1
mypy.ini
|
@ -90,7 +90,6 @@ exclude = (?x)
|
||||||
|tests/push/test_push_rule_evaluator.py
|
|tests/push/test_push_rule_evaluator.py
|
||||||
|tests/rest/client/test_transactions.py
|
|tests/rest/client/test_transactions.py
|
||||||
|tests/rest/media/v1/test_media_storage.py
|
|tests/rest/media/v1/test_media_storage.py
|
||||||
|tests/rest/media/v1/test_url_preview.py
|
|
||||||
|tests/scripts/test_new_matrix_user.py
|
|tests/scripts/test_new_matrix_user.py
|
||||||
|tests/server.py
|
|tests/server.py
|
||||||
|tests/server_notices/test_resource_limits_server_notices.py
|
|tests/server_notices/test_resource_limits_server_notices.py
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from http import HTTPStatus
|
||||||
from unittest.mock import Mock, call
|
from unittest.mock import Mock, call
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
|
@ -11,14 +26,14 @@ from tests.utils import MockClock
|
||||||
|
|
||||||
|
|
||||||
class HttpTransactionCacheTestCase(unittest.TestCase):
|
class HttpTransactionCacheTestCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
self.clock = MockClock()
|
self.clock = MockClock()
|
||||||
self.hs = Mock()
|
self.hs = Mock()
|
||||||
self.hs.get_clock = Mock(return_value=self.clock)
|
self.hs.get_clock = Mock(return_value=self.clock)
|
||||||
self.hs.get_auth = Mock()
|
self.hs.get_auth = Mock()
|
||||||
self.cache = HttpTransactionCache(self.hs)
|
self.cache = HttpTransactionCache(self.hs)
|
||||||
|
|
||||||
self.mock_http_response = (200, "GOOD JOB!")
|
self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
|
||||||
self.mock_key = "foo"
|
self.mock_key = "foo"
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -16,7 +16,7 @@ import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
from binascii import unhexlify
|
from binascii import unhexlify
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional
|
from typing import Any, BinaryIO, Dict, List, Optional, Union
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
|
|
||||||
|
@ -26,18 +26,24 @@ from PIL import Image as Image
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.events.spamcheck import load_legacy_spam_checkers
|
from synapse.events.spamcheck import load_legacy_spam_checkers
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
|
from synapse.module_api import ModuleApi
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login
|
from synapse.rest.client import login
|
||||||
from synapse.rest.media.v1._base import FileInfo
|
from synapse.rest.media.v1._base import FileInfo
|
||||||
from synapse.rest.media.v1.filepath import MediaFilePaths
|
from synapse.rest.media.v1.filepath import MediaFilePaths
|
||||||
from synapse.rest.media.v1.media_storage import MediaStorage
|
from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
|
||||||
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
|
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import RoomAlias
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import FakeSite, make_request
|
from tests.server import FakeChannel, FakeSite, make_request
|
||||||
from tests.test_utils import SMALL_PNG
|
from tests.test_utils import SMALL_PNG
|
||||||
from tests.utils import default_config
|
from tests.utils import default_config
|
||||||
|
|
||||||
|
@ -46,7 +52,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
needs_threadpool = True
|
needs_threadpool = True
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
|
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
|
||||||
self.addCleanup(shutil.rmtree, self.test_dir)
|
self.addCleanup(shutil.rmtree, self.test_dir)
|
||||||
|
|
||||||
|
@ -62,7 +68,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
|
||||||
hs, self.primary_base_path, self.filepaths, storage_providers
|
hs, self.primary_base_path, self.filepaths, storage_providers
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_ensure_media_is_in_local_cache(self):
|
def test_ensure_media_is_in_local_cache(self) -> None:
|
||||||
media_id = "some_media_id"
|
media_id = "some_media_id"
|
||||||
test_body = "Test\n"
|
test_body = "Test\n"
|
||||||
|
|
||||||
|
@ -105,7 +111,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(test_body, body)
|
self.assertEqual(test_body, body)
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True)
|
@attr.s(auto_attribs=True, slots=True, frozen=True)
|
||||||
class _TestImage:
|
class _TestImage:
|
||||||
"""An image for testing thumbnailing with the expected results
|
"""An image for testing thumbnailing with the expected results
|
||||||
|
|
||||||
|
@ -121,18 +127,18 @@ class _TestImage:
|
||||||
a 404 is expected.
|
a 404 is expected.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data = attr.ib(type=bytes)
|
data: bytes
|
||||||
content_type = attr.ib(type=bytes)
|
content_type: bytes
|
||||||
extension = attr.ib(type=bytes)
|
extension: bytes
|
||||||
expected_cropped = attr.ib(type=Optional[bytes], default=None)
|
expected_cropped: Optional[bytes] = None
|
||||||
expected_scaled = attr.ib(type=Optional[bytes], default=None)
|
expected_scaled: Optional[bytes] = None
|
||||||
expected_found = attr.ib(default=True, type=bool)
|
expected_found: bool = True
|
||||||
|
|
||||||
|
|
||||||
@parameterized_class(
|
@parameterized_class(
|
||||||
("test_image",),
|
("test_image",),
|
||||||
[
|
[
|
||||||
# smoll png
|
# small png
|
||||||
(
|
(
|
||||||
_TestImage(
|
_TestImage(
|
||||||
SMALL_PNG,
|
SMALL_PNG,
|
||||||
|
@ -193,11 +199,17 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
hijack_auth = True
|
hijack_auth = True
|
||||||
user_id = "@test:user"
|
user_id = "@test:user"
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
|
||||||
self.fetches = []
|
self.fetches = []
|
||||||
|
|
||||||
def get_file(destination, path, output_stream, args=None, max_size=None):
|
def get_file(
|
||||||
|
destination: str,
|
||||||
|
path: str,
|
||||||
|
output_stream: BinaryIO,
|
||||||
|
args: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||||
|
max_size: Optional[int] = None,
|
||||||
|
) -> Deferred:
|
||||||
"""
|
"""
|
||||||
Returns tuple[int,dict,str,int] of file length, response headers,
|
Returns tuple[int,dict,str,int] of file length, response headers,
|
||||||
absolute URI, and response code.
|
absolute URI, and response code.
|
||||||
|
@ -238,7 +250,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
|
||||||
media_resource = hs.get_media_repository_resource()
|
media_resource = hs.get_media_repository_resource()
|
||||||
self.download_resource = media_resource.children[b"download"]
|
self.download_resource = media_resource.children[b"download"]
|
||||||
|
@ -248,8 +260,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.media_id = "example.com/12345"
|
self.media_id = "example.com/12345"
|
||||||
|
|
||||||
def _req(self, content_disposition, include_content_type=True):
|
def _req(
|
||||||
|
self, content_disposition: Optional[bytes], include_content_type: bool = True
|
||||||
|
) -> FakeChannel:
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.reactor,
|
self.reactor,
|
||||||
FakeSite(self.download_resource, self.reactor),
|
FakeSite(self.download_resource, self.reactor),
|
||||||
|
@ -288,7 +301,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return channel
|
return channel
|
||||||
|
|
||||||
def test_handle_missing_content_type(self):
|
def test_handle_missing_content_type(self) -> None:
|
||||||
channel = self._req(
|
channel = self._req(
|
||||||
b"inline; filename=out" + self.test_image.extension,
|
b"inline; filename=out" + self.test_image.extension,
|
||||||
include_content_type=False,
|
include_content_type=False,
|
||||||
|
@ -299,7 +312,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
|
headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_disposition_filename_ascii(self):
|
def test_disposition_filename_ascii(self) -> None:
|
||||||
"""
|
"""
|
||||||
If the filename is filename=<ascii> then Synapse will decode it as an
|
If the filename is filename=<ascii> then Synapse will decode it as an
|
||||||
ASCII string, and use filename= in the response.
|
ASCII string, and use filename= in the response.
|
||||||
|
@ -315,7 +328,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
[b"inline; filename=out" + self.test_image.extension],
|
[b"inline; filename=out" + self.test_image.extension],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_disposition_filenamestar_utf8escaped(self):
|
def test_disposition_filenamestar_utf8escaped(self) -> None:
|
||||||
"""
|
"""
|
||||||
If the filename is filename=*utf8''<utf8 escaped> then Synapse will
|
If the filename is filename=*utf8''<utf8 escaped> then Synapse will
|
||||||
correctly decode it as the UTF-8 string, and use filename* in the
|
correctly decode it as the UTF-8 string, and use filename* in the
|
||||||
|
@ -335,7 +348,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
[b"inline; filename*=utf-8''" + filename + self.test_image.extension],
|
[b"inline; filename*=utf-8''" + filename + self.test_image.extension],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_disposition_none(self):
|
def test_disposition_none(self) -> None:
|
||||||
"""
|
"""
|
||||||
If there is no filename, one isn't passed on in the Content-Disposition
|
If there is no filename, one isn't passed on in the Content-Disposition
|
||||||
of the request.
|
of the request.
|
||||||
|
@ -348,26 +361,26 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
|
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
|
||||||
|
|
||||||
def test_thumbnail_crop(self):
|
def test_thumbnail_crop(self) -> None:
|
||||||
"""Test that a cropped remote thumbnail is available."""
|
"""Test that a cropped remote thumbnail is available."""
|
||||||
self._test_thumbnail(
|
self._test_thumbnail(
|
||||||
"crop", self.test_image.expected_cropped, self.test_image.expected_found
|
"crop", self.test_image.expected_cropped, self.test_image.expected_found
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_thumbnail_scale(self):
|
def test_thumbnail_scale(self) -> None:
|
||||||
"""Test that a scaled remote thumbnail is available."""
|
"""Test that a scaled remote thumbnail is available."""
|
||||||
self._test_thumbnail(
|
self._test_thumbnail(
|
||||||
"scale", self.test_image.expected_scaled, self.test_image.expected_found
|
"scale", self.test_image.expected_scaled, self.test_image.expected_found
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_invalid_type(self):
|
def test_invalid_type(self) -> None:
|
||||||
"""An invalid thumbnail type is never available."""
|
"""An invalid thumbnail type is never available."""
|
||||||
self._test_thumbnail("invalid", None, False)
|
self._test_thumbnail("invalid", None, False)
|
||||||
|
|
||||||
@unittest.override_config(
|
@unittest.override_config(
|
||||||
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
|
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
|
||||||
)
|
)
|
||||||
def test_no_thumbnail_crop(self):
|
def test_no_thumbnail_crop(self) -> None:
|
||||||
"""
|
"""
|
||||||
Override the config to generate only scaled thumbnails, but request a cropped one.
|
Override the config to generate only scaled thumbnails, but request a cropped one.
|
||||||
"""
|
"""
|
||||||
|
@ -376,13 +389,13 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
@unittest.override_config(
|
@unittest.override_config(
|
||||||
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
|
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
|
||||||
)
|
)
|
||||||
def test_no_thumbnail_scale(self):
|
def test_no_thumbnail_scale(self) -> None:
|
||||||
"""
|
"""
|
||||||
Override the config to generate only cropped thumbnails, but request a scaled one.
|
Override the config to generate only cropped thumbnails, but request a scaled one.
|
||||||
"""
|
"""
|
||||||
self._test_thumbnail("scale", None, False)
|
self._test_thumbnail("scale", None, False)
|
||||||
|
|
||||||
def test_thumbnail_repeated_thumbnail(self):
|
def test_thumbnail_repeated_thumbnail(self) -> None:
|
||||||
"""Test that fetching the same thumbnail works, and deleting the on disk
|
"""Test that fetching the same thumbnail works, and deleting the on disk
|
||||||
thumbnail regenerates it.
|
thumbnail regenerates it.
|
||||||
"""
|
"""
|
||||||
|
@ -443,7 +456,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
channel.result["body"],
|
channel.result["body"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _test_thumbnail(self, method, expected_body, expected_found):
|
def _test_thumbnail(
|
||||||
|
self, method: str, expected_body: Optional[bytes], expected_found: bool
|
||||||
|
) -> None:
|
||||||
params = "?width=32&height=32&method=" + method
|
params = "?width=32&height=32&method=" + method
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.reactor,
|
self.reactor,
|
||||||
|
@ -485,7 +500,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
|
@parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
|
||||||
def test_same_quality(self, method, desired_size):
|
def test_same_quality(self, method: str, desired_size: int) -> None:
|
||||||
"""Test that choosing between thumbnails with the same quality rating succeeds.
|
"""Test that choosing between thumbnails with the same quality rating succeeds.
|
||||||
|
|
||||||
We are not particular about which thumbnail is chosen."""
|
We are not particular about which thumbnail is chosen."""
|
||||||
|
@ -521,7 +536,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_x_robots_tag_header(self):
|
def test_x_robots_tag_header(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests that the `X-Robots-Tag` header is present, which informs web crawlers
|
Tests that the `X-Robots-Tag` header is present, which informs web crawlers
|
||||||
to not index, archive, or follow links in media.
|
to not index, archive, or follow links in media.
|
||||||
|
@ -540,29 +555,38 @@ class TestSpamChecker:
|
||||||
`evil`.
|
`evil`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, api):
|
def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.api = api
|
self.api = api
|
||||||
|
|
||||||
def parse_config(config):
|
def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
return config
|
return config
|
||||||
|
|
||||||
async def check_event_for_spam(self, foo):
|
async def check_event_for_spam(self, event: EventBase) -> Union[bool, str]:
|
||||||
return False # allow all events
|
return False # allow all events
|
||||||
|
|
||||||
async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
|
async def user_may_invite(
|
||||||
|
self,
|
||||||
|
inviter_userid: str,
|
||||||
|
invitee_userid: str,
|
||||||
|
room_id: str,
|
||||||
|
) -> bool:
|
||||||
return True # allow all invites
|
return True # allow all invites
|
||||||
|
|
||||||
async def user_may_create_room(self, userid):
|
async def user_may_create_room(self, userid: str) -> bool:
|
||||||
return True # allow all room creations
|
return True # allow all room creations
|
||||||
|
|
||||||
async def user_may_create_room_alias(self, userid, room_alias):
|
async def user_may_create_room_alias(
|
||||||
|
self, userid: str, room_alias: RoomAlias
|
||||||
|
) -> bool:
|
||||||
return True # allow all room aliases
|
return True # allow all room aliases
|
||||||
|
|
||||||
async def user_may_publish_room(self, userid, room_id):
|
async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
|
||||||
return True # allow publishing of all rooms
|
return True # allow publishing of all rooms
|
||||||
|
|
||||||
async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
|
async def check_media_file_for_spam(
|
||||||
|
self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
|
||||||
|
) -> bool:
|
||||||
buf = BytesIO()
|
buf = BytesIO()
|
||||||
await file_wrapper.write_chunks_to(buf.write)
|
await file_wrapper.write_chunks_to(buf.write)
|
||||||
|
|
||||||
|
@ -575,7 +599,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
|
||||||
admin.register_servlets,
|
admin.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.user = self.register_user("user", "pass")
|
self.user = self.register_user("user", "pass")
|
||||||
self.tok = self.login("user", "pass")
|
self.tok = self.login("user", "pass")
|
||||||
|
|
||||||
|
@ -586,7 +610,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
load_legacy_spam_checkers(hs)
|
load_legacy_spam_checkers(hs)
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = default_config("test")
|
config = default_config("test")
|
||||||
|
|
||||||
config.update(
|
config.update(
|
||||||
|
@ -602,13 +626,13 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def test_upload_innocent(self):
|
def test_upload_innocent(self) -> None:
|
||||||
"""Attempt to upload some innocent data that should be allowed."""
|
"""Attempt to upload some innocent data that should be allowed."""
|
||||||
self.helper.upload_media(
|
self.helper.upload_media(
|
||||||
self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
|
self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_upload_ban(self):
|
def test_upload_ban(self) -> None:
|
||||||
"""Attempt to upload some data that includes bytes "evil", which should
|
"""Attempt to upload some data that includes bytes "evil", which should
|
||||||
get rejected by the spam checker.
|
get rejected by the spam checker.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -16,16 +16,21 @@ import base64
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from typing import Any, Dict, Optional, Sequence, Tuple, Type
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
from twisted.internet._resolver import HostResolution
|
from twisted.internet._resolver import HostResolution
|
||||||
from twisted.internet.address import IPv4Address, IPv6Address
|
from twisted.internet.address import IPv4Address, IPv6Address
|
||||||
from twisted.internet.error import DNSLookupError
|
from twisted.internet.error import DNSLookupError
|
||||||
from twisted.test.proto_helpers import AccumulatingProtocol
|
from twisted.internet.interfaces import IAddress, IResolutionReceiver
|
||||||
|
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
|
||||||
|
|
||||||
from synapse.config.oembed import OEmbedEndpointConfig
|
from synapse.config.oembed import OEmbedEndpointConfig
|
||||||
|
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
||||||
from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
|
from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util import Clock
|
||||||
from synapse.util.stringutils import parse_and_validate_mxc_uri
|
from synapse.util.stringutils import parse_and_validate_mxc_uri
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -52,7 +57,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
b"</head></html>"
|
b"</head></html>"
|
||||||
)
|
)
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
|
||||||
config = self.default_config()
|
config = self.default_config()
|
||||||
config["url_preview_enabled"] = True
|
config["url_preview_enabled"] = True
|
||||||
|
@ -113,22 +118,22 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
|
||||||
self.media_repo = hs.get_media_repository_resource()
|
self.media_repo = hs.get_media_repository_resource()
|
||||||
self.preview_url = self.media_repo.children[b"preview_url"]
|
self.preview_url = self.media_repo.children[b"preview_url"]
|
||||||
|
|
||||||
self.lookups = {}
|
self.lookups: Dict[str, Any] = {}
|
||||||
|
|
||||||
class Resolver:
|
class Resolver:
|
||||||
def resolveHostName(
|
def resolveHostName(
|
||||||
_self,
|
_self,
|
||||||
resolutionReceiver,
|
resolutionReceiver: IResolutionReceiver,
|
||||||
hostName,
|
hostName: str,
|
||||||
portNumber=0,
|
portNumber: int = 0,
|
||||||
addressTypes=None,
|
addressTypes: Optional[Sequence[Type[IAddress]]] = None,
|
||||||
transportSemantics="TCP",
|
transportSemantics: str = "TCP",
|
||||||
):
|
) -> IResolutionReceiver:
|
||||||
|
|
||||||
resolution = HostResolution(hostName)
|
resolution = HostResolution(hostName)
|
||||||
resolutionReceiver.resolutionBegan(resolution)
|
resolutionReceiver.resolutionBegan(resolution)
|
||||||
|
@ -140,9 +145,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
resolutionReceiver.resolutionComplete()
|
resolutionReceiver.resolutionComplete()
|
||||||
return resolutionReceiver
|
return resolutionReceiver
|
||||||
|
|
||||||
self.reactor.nameResolver = Resolver()
|
self.reactor.nameResolver = Resolver() # type: ignore[assignment]
|
||||||
|
|
||||||
def create_test_resource(self):
|
def create_test_resource(self) -> MediaRepositoryResource:
|
||||||
return self.hs.get_media_repository_resource()
|
return self.hs.get_media_repository_resource()
|
||||||
|
|
||||||
def _assert_small_png(self, json_body: JsonDict) -> None:
|
def _assert_small_png(self, json_body: JsonDict) -> None:
|
||||||
|
@ -153,7 +158,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(json_body["og:image:type"], "image/png")
|
self.assertEqual(json_body["og:image:type"], "image/png")
|
||||||
self.assertEqual(json_body["matrix:image:size"], 67)
|
self.assertEqual(json_body["matrix:image:size"], 67)
|
||||||
|
|
||||||
def test_cache_returns_correct_type(self):
|
def test_cache_returns_correct_type(self) -> None:
|
||||||
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
|
@ -207,7 +212,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
|
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_non_ascii_preview_httpequiv(self):
|
def test_non_ascii_preview_httpequiv(self) -> None:
|
||||||
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
||||||
|
|
||||||
end_content = (
|
end_content = (
|
||||||
|
@ -243,7 +248,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
|
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
|
||||||
|
|
||||||
def test_video_rejected(self):
|
def test_video_rejected(self) -> None:
|
||||||
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
||||||
|
|
||||||
end_content = b"anything"
|
end_content = b"anything"
|
||||||
|
@ -279,7 +284,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_audio_rejected(self):
|
def test_audio_rejected(self) -> None:
|
||||||
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
||||||
|
|
||||||
end_content = b"anything"
|
end_content = b"anything"
|
||||||
|
@ -315,7 +320,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_non_ascii_preview_content_type(self):
|
def test_non_ascii_preview_content_type(self) -> None:
|
||||||
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
||||||
|
|
||||||
end_content = (
|
end_content = (
|
||||||
|
@ -350,7 +355,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
|
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
|
||||||
|
|
||||||
def test_overlong_title(self):
|
def test_overlong_title(self) -> None:
|
||||||
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
||||||
|
|
||||||
end_content = (
|
end_content = (
|
||||||
|
@ -387,7 +392,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
# We should only see the `og:description` field, as `title` is too long and should be stripped out
|
# We should only see the `og:description` field, as `title` is too long and should be stripped out
|
||||||
self.assertCountEqual(["og:description"], res.keys())
|
self.assertCountEqual(["og:description"], res.keys())
|
||||||
|
|
||||||
def test_ipaddr(self):
|
def test_ipaddr(self) -> None:
|
||||||
"""
|
"""
|
||||||
IP addresses can be previewed directly.
|
IP addresses can be previewed directly.
|
||||||
"""
|
"""
|
||||||
|
@ -417,7 +422,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
|
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_blacklisted_ip_specific(self):
|
def test_blacklisted_ip_specific(self) -> None:
|
||||||
"""
|
"""
|
||||||
Blacklisted IP addresses, found via DNS, are not spidered.
|
Blacklisted IP addresses, found via DNS, are not spidered.
|
||||||
"""
|
"""
|
||||||
|
@ -438,7 +443,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_blacklisted_ip_range(self):
|
def test_blacklisted_ip_range(self) -> None:
|
||||||
"""
|
"""
|
||||||
Blacklisted IP ranges, IPs found over DNS, are not spidered.
|
Blacklisted IP ranges, IPs found over DNS, are not spidered.
|
||||||
"""
|
"""
|
||||||
|
@ -457,7 +462,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_blacklisted_ip_specific_direct(self):
|
def test_blacklisted_ip_specific_direct(self) -> None:
|
||||||
"""
|
"""
|
||||||
Blacklisted IP addresses, accessed directly, are not spidered.
|
Blacklisted IP addresses, accessed directly, are not spidered.
|
||||||
"""
|
"""
|
||||||
|
@ -476,7 +481,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 403)
|
self.assertEqual(channel.code, 403)
|
||||||
|
|
||||||
def test_blacklisted_ip_range_direct(self):
|
def test_blacklisted_ip_range_direct(self) -> None:
|
||||||
"""
|
"""
|
||||||
Blacklisted IP ranges, accessed directly, are not spidered.
|
Blacklisted IP ranges, accessed directly, are not spidered.
|
||||||
"""
|
"""
|
||||||
|
@ -493,7 +498,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_blacklisted_ip_range_whitelisted_ip(self):
|
def test_blacklisted_ip_range_whitelisted_ip(self) -> None:
|
||||||
"""
|
"""
|
||||||
Blacklisted but then subsequently whitelisted IP addresses can be
|
Blacklisted but then subsequently whitelisted IP addresses can be
|
||||||
spidered.
|
spidered.
|
||||||
|
@ -526,7 +531,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
|
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_blacklisted_ip_with_external_ip(self):
|
def test_blacklisted_ip_with_external_ip(self) -> None:
|
||||||
"""
|
"""
|
||||||
If a hostname resolves a blacklisted IP, even if there's a
|
If a hostname resolves a blacklisted IP, even if there's a
|
||||||
non-blacklisted one, it will be rejected.
|
non-blacklisted one, it will be rejected.
|
||||||
|
@ -549,7 +554,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_blacklisted_ipv6_specific(self):
|
def test_blacklisted_ipv6_specific(self) -> None:
|
||||||
"""
|
"""
|
||||||
Blacklisted IP addresses, found via DNS, are not spidered.
|
Blacklisted IP addresses, found via DNS, are not spidered.
|
||||||
"""
|
"""
|
||||||
|
@ -572,7 +577,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_blacklisted_ipv6_range(self):
|
def test_blacklisted_ipv6_range(self) -> None:
|
||||||
"""
|
"""
|
||||||
Blacklisted IP ranges, IPs found over DNS, are not spidered.
|
Blacklisted IP ranges, IPs found over DNS, are not spidered.
|
||||||
"""
|
"""
|
||||||
|
@ -591,7 +596,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_OPTIONS(self):
|
def test_OPTIONS(self) -> None:
|
||||||
"""
|
"""
|
||||||
OPTIONS returns the OPTIONS.
|
OPTIONS returns the OPTIONS.
|
||||||
"""
|
"""
|
||||||
|
@ -601,7 +606,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
self.assertEqual(channel.json_body, {})
|
self.assertEqual(channel.json_body, {})
|
||||||
|
|
||||||
def test_accept_language_config_option(self):
|
def test_accept_language_config_option(self) -> None:
|
||||||
"""
|
"""
|
||||||
Accept-Language header is sent to the remote server
|
Accept-Language header is sent to the remote server
|
||||||
"""
|
"""
|
||||||
|
@ -652,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
server.data,
|
server.data,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_data_url(self):
|
def test_data_url(self) -> None:
|
||||||
"""
|
"""
|
||||||
Requesting to preview a data URL is not supported.
|
Requesting to preview a data URL is not supported.
|
||||||
"""
|
"""
|
||||||
|
@ -675,7 +680,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(channel.code, 500)
|
self.assertEqual(channel.code, 500)
|
||||||
|
|
||||||
def test_inline_data_url(self):
|
def test_inline_data_url(self) -> None:
|
||||||
"""
|
"""
|
||||||
An inline image (as a data URL) should be parsed properly.
|
An inline image (as a data URL) should be parsed properly.
|
||||||
"""
|
"""
|
||||||
|
@ -712,7 +717,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
self._assert_small_png(channel.json_body)
|
self._assert_small_png(channel.json_body)
|
||||||
|
|
||||||
def test_oembed_photo(self):
|
def test_oembed_photo(self) -> None:
|
||||||
"""Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
|
"""Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
|
||||||
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
|
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
|
||||||
self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
|
self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
|
||||||
|
@ -771,7 +776,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345")
|
self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345")
|
||||||
self._assert_small_png(body)
|
self._assert_small_png(body)
|
||||||
|
|
||||||
def test_oembed_rich(self):
|
def test_oembed_rich(self) -> None:
|
||||||
"""Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
|
"""Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
|
||||||
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
|
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
|
||||||
|
|
||||||
|
@ -817,7 +822,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_oembed_format(self):
|
def test_oembed_format(self) -> None:
|
||||||
"""Test an oEmbed endpoint which requires the format in the URL."""
|
"""Test an oEmbed endpoint which requires the format in the URL."""
|
||||||
self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")]
|
self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")]
|
||||||
|
|
||||||
|
@ -866,7 +871,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_oembed_autodiscovery(self):
|
def test_oembed_autodiscovery(self) -> None:
|
||||||
"""
|
"""
|
||||||
Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL.
|
Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL.
|
||||||
1. Request a preview of a URL which is not known to the oEmbed code.
|
1. Request a preview of a URL which is not known to the oEmbed code.
|
||||||
|
@ -962,7 +967,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self._assert_small_png(body)
|
self._assert_small_png(body)
|
||||||
|
|
||||||
def _download_image(self):
|
def _download_image(self) -> Tuple[str, str]:
|
||||||
"""Downloads an image into the URL cache.
|
"""Downloads an image into the URL cache.
|
||||||
Returns:
|
Returns:
|
||||||
A (host, media_id) tuple representing the MXC URI of the image.
|
A (host, media_id) tuple representing the MXC URI of the image.
|
||||||
|
@ -995,7 +1000,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
self.assertIsNone(_port)
|
self.assertIsNone(_port)
|
||||||
return host, media_id
|
return host, media_id
|
||||||
|
|
||||||
def test_storage_providers_exclude_files(self):
|
def test_storage_providers_exclude_files(self) -> None:
|
||||||
"""Test that files are not stored in or fetched from storage providers."""
|
"""Test that files are not stored in or fetched from storage providers."""
|
||||||
host, media_id = self._download_image()
|
host, media_id = self._download_image()
|
||||||
|
|
||||||
|
@ -1037,7 +1042,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
"URL cache file was unexpectedly retrieved from a storage provider",
|
"URL cache file was unexpectedly retrieved from a storage provider",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_storage_providers_exclude_thumbnails(self):
|
def test_storage_providers_exclude_thumbnails(self) -> None:
|
||||||
"""Test that thumbnails are not stored in or fetched from storage providers."""
|
"""Test that thumbnails are not stored in or fetched from storage providers."""
|
||||||
host, media_id = self._download_image()
|
host, media_id = self._download_image()
|
||||||
|
|
||||||
|
@ -1090,7 +1095,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||||
"URL cache thumbnail was unexpectedly retrieved from a storage provider",
|
"URL cache thumbnail was unexpectedly retrieved from a storage provider",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_cache_expiry(self):
|
def test_cache_expiry(self) -> None:
|
||||||
"""Test that URL cache files and thumbnails are cleaned up properly on expiry."""
|
"""Test that URL cache files and thumbnails are cleaned up properly on expiry."""
|
||||||
self.preview_url.clock = MockClock()
|
self.preview_url.clock = MockClock()
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue