0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-05 00:03:53 +01:00

Add type hints to tests/rest. ()

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
Dirk Klimpel 2022-03-11 13:42:22 +01:00 committed by GitHub
parent e10a2fe0c2
commit 32c828d0f7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 129 additions and 85 deletions

1
changelog.d/12208.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to tests files.

View file

@ -90,7 +90,6 @@ exclude = (?x)
|tests/push/test_push_rule_evaluator.py |tests/push/test_push_rule_evaluator.py
|tests/rest/client/test_transactions.py |tests/rest/client/test_transactions.py
|tests/rest/media/v1/test_media_storage.py |tests/rest/media/v1/test_media_storage.py
|tests/rest/media/v1/test_url_preview.py
|tests/scripts/test_new_matrix_user.py |tests/scripts/test_new_matrix_user.py
|tests/server.py |tests/server.py
|tests/server_notices/test_resource_limits_server_notices.py |tests/server_notices/test_resource_limits_server_notices.py

View file

@ -1,3 +1,18 @@
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from unittest.mock import Mock, call from unittest.mock import Mock, call
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
@ -11,14 +26,14 @@ from tests.utils import MockClock
class HttpTransactionCacheTestCase(unittest.TestCase): class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.clock = MockClock() self.clock = MockClock()
self.hs = Mock() self.hs = Mock()
self.hs.get_clock = Mock(return_value=self.clock) self.hs.get_clock = Mock(return_value=self.clock)
self.hs.get_auth = Mock() self.hs.get_auth = Mock()
self.cache = HttpTransactionCache(self.hs) self.cache = HttpTransactionCache(self.hs)
self.mock_http_response = (200, "GOOD JOB!") self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
self.mock_key = "foo" self.mock_key = "foo"
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -16,7 +16,7 @@ import shutil
import tempfile import tempfile
from binascii import unhexlify from binascii import unhexlify
from io import BytesIO from io import BytesIO
from typing import Optional from typing import Any, BinaryIO, Dict, List, Optional, Union
from unittest.mock import Mock from unittest.mock import Mock
from urllib import parse from urllib import parse
@ -26,18 +26,24 @@ from PIL import Image as Image
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.events import EventBase
from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.module_api import ModuleApi
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login from synapse.rest.client import login
from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from synapse.server import HomeServer
from synapse.types import RoomAlias
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request from tests.server import FakeChannel, FakeSite, make_request
from tests.test_utils import SMALL_PNG from tests.test_utils import SMALL_PNG
from tests.utils import default_config from tests.utils import default_config
@ -46,7 +52,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
needs_threadpool = True needs_threadpool = True
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
self.addCleanup(shutil.rmtree, self.test_dir) self.addCleanup(shutil.rmtree, self.test_dir)
@ -62,7 +68,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
hs, self.primary_base_path, self.filepaths, storage_providers hs, self.primary_base_path, self.filepaths, storage_providers
) )
def test_ensure_media_is_in_local_cache(self): def test_ensure_media_is_in_local_cache(self) -> None:
media_id = "some_media_id" media_id = "some_media_id"
test_body = "Test\n" test_body = "Test\n"
@ -105,7 +111,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
self.assertEqual(test_body, body) self.assertEqual(test_body, body)
@attr.s(slots=True, frozen=True) @attr.s(auto_attribs=True, slots=True, frozen=True)
class _TestImage: class _TestImage:
"""An image for testing thumbnailing with the expected results """An image for testing thumbnailing with the expected results
@ -121,18 +127,18 @@ class _TestImage:
a 404 is expected. a 404 is expected.
""" """
data = attr.ib(type=bytes) data: bytes
content_type = attr.ib(type=bytes) content_type: bytes
extension = attr.ib(type=bytes) extension: bytes
expected_cropped = attr.ib(type=Optional[bytes], default=None) expected_cropped: Optional[bytes] = None
expected_scaled = attr.ib(type=Optional[bytes], default=None) expected_scaled: Optional[bytes] = None
expected_found = attr.ib(default=True, type=bool) expected_found: bool = True
@parameterized_class( @parameterized_class(
("test_image",), ("test_image",),
[ [
# smoll png # small png
( (
_TestImage( _TestImage(
SMALL_PNG, SMALL_PNG,
@ -193,11 +199,17 @@ class MediaRepoTests(unittest.HomeserverTestCase):
hijack_auth = True hijack_auth = True
user_id = "@test:user" user_id = "@test:user"
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.fetches = [] self.fetches = []
def get_file(destination, path, output_stream, args=None, max_size=None): def get_file(
destination: str,
path: str,
output_stream: BinaryIO,
args: Optional[Dict[str, Union[str, List[str]]]] = None,
max_size: Optional[int] = None,
) -> Deferred:
""" """
Returns tuple[int,dict,str,int] of file length, response headers, Returns tuple[int,dict,str,int] of file length, response headers,
absolute URI, and response code. absolute URI, and response code.
@ -238,7 +250,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
return hs return hs
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
media_resource = hs.get_media_repository_resource() media_resource = hs.get_media_repository_resource()
self.download_resource = media_resource.children[b"download"] self.download_resource = media_resource.children[b"download"]
@ -248,8 +260,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.media_id = "example.com/12345" self.media_id = "example.com/12345"
def _req(self, content_disposition, include_content_type=True): def _req(
self, content_disposition: Optional[bytes], include_content_type: bool = True
) -> FakeChannel:
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.download_resource, self.reactor), FakeSite(self.download_resource, self.reactor),
@ -288,7 +301,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
return channel return channel
def test_handle_missing_content_type(self): def test_handle_missing_content_type(self) -> None:
channel = self._req( channel = self._req(
b"inline; filename=out" + self.test_image.extension, b"inline; filename=out" + self.test_image.extension,
include_content_type=False, include_content_type=False,
@ -299,7 +312,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"] headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
) )
def test_disposition_filename_ascii(self): def test_disposition_filename_ascii(self) -> None:
""" """
If the filename is filename=<ascii> then Synapse will decode it as an If the filename is filename=<ascii> then Synapse will decode it as an
ASCII string, and use filename= in the response. ASCII string, and use filename= in the response.
@ -315,7 +328,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
[b"inline; filename=out" + self.test_image.extension], [b"inline; filename=out" + self.test_image.extension],
) )
def test_disposition_filenamestar_utf8escaped(self): def test_disposition_filenamestar_utf8escaped(self) -> None:
""" """
If the filename is filename=*utf8''<utf8 escaped> then Synapse will If the filename is filename=*utf8''<utf8 escaped> then Synapse will
correctly decode it as the UTF-8 string, and use filename* in the correctly decode it as the UTF-8 string, and use filename* in the
@ -335,7 +348,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
[b"inline; filename*=utf-8''" + filename + self.test_image.extension], [b"inline; filename*=utf-8''" + filename + self.test_image.extension],
) )
def test_disposition_none(self): def test_disposition_none(self) -> None:
""" """
If there is no filename, one isn't passed on in the Content-Disposition If there is no filename, one isn't passed on in the Content-Disposition
of the request. of the request.
@ -348,26 +361,26 @@ class MediaRepoTests(unittest.HomeserverTestCase):
) )
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self): def test_thumbnail_crop(self) -> None:
"""Test that a cropped remote thumbnail is available.""" """Test that a cropped remote thumbnail is available."""
self._test_thumbnail( self._test_thumbnail(
"crop", self.test_image.expected_cropped, self.test_image.expected_found "crop", self.test_image.expected_cropped, self.test_image.expected_found
) )
def test_thumbnail_scale(self): def test_thumbnail_scale(self) -> None:
"""Test that a scaled remote thumbnail is available.""" """Test that a scaled remote thumbnail is available."""
self._test_thumbnail( self._test_thumbnail(
"scale", self.test_image.expected_scaled, self.test_image.expected_found "scale", self.test_image.expected_scaled, self.test_image.expected_found
) )
def test_invalid_type(self): def test_invalid_type(self) -> None:
"""An invalid thumbnail type is never available.""" """An invalid thumbnail type is never available."""
self._test_thumbnail("invalid", None, False) self._test_thumbnail("invalid", None, False)
@unittest.override_config( @unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]} {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
) )
def test_no_thumbnail_crop(self): def test_no_thumbnail_crop(self) -> None:
""" """
Override the config to generate only scaled thumbnails, but request a cropped one. Override the config to generate only scaled thumbnails, but request a cropped one.
""" """
@ -376,13 +389,13 @@ class MediaRepoTests(unittest.HomeserverTestCase):
@unittest.override_config( @unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]} {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
) )
def test_no_thumbnail_scale(self): def test_no_thumbnail_scale(self) -> None:
""" """
Override the config to generate only cropped thumbnails, but request a scaled one. Override the config to generate only cropped thumbnails, but request a scaled one.
""" """
self._test_thumbnail("scale", None, False) self._test_thumbnail("scale", None, False)
def test_thumbnail_repeated_thumbnail(self): def test_thumbnail_repeated_thumbnail(self) -> None:
"""Test that fetching the same thumbnail works, and deleting the on disk """Test that fetching the same thumbnail works, and deleting the on disk
thumbnail regenerates it. thumbnail regenerates it.
""" """
@ -443,7 +456,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.result["body"], channel.result["body"],
) )
def _test_thumbnail(self, method, expected_body, expected_found): def _test_thumbnail(
self, method: str, expected_body: Optional[bytes], expected_found: bool
) -> None:
params = "?width=32&height=32&method=" + method params = "?width=32&height=32&method=" + method
channel = make_request( channel = make_request(
self.reactor, self.reactor,
@ -485,7 +500,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
) )
@parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)]) @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
def test_same_quality(self, method, desired_size): def test_same_quality(self, method: str, desired_size: int) -> None:
"""Test that choosing between thumbnails with the same quality rating succeeds. """Test that choosing between thumbnails with the same quality rating succeeds.
We are not particular about which thumbnail is chosen.""" We are not particular about which thumbnail is chosen."""
@ -521,7 +536,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
) )
) )
def test_x_robots_tag_header(self): def test_x_robots_tag_header(self) -> None:
""" """
Tests that the `X-Robots-Tag` header is present, which informs web crawlers Tests that the `X-Robots-Tag` header is present, which informs web crawlers
to not index, archive, or follow links in media. to not index, archive, or follow links in media.
@ -540,29 +555,38 @@ class TestSpamChecker:
`evil`. `evil`.
""" """
def __init__(self, config, api): def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None:
self.config = config self.config = config
self.api = api self.api = api
def parse_config(config): def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
return config return config
async def check_event_for_spam(self, foo): async def check_event_for_spam(self, event: EventBase) -> Union[bool, str]:
return False # allow all events return False # allow all events
async def user_may_invite(self, inviter_userid, invitee_userid, room_id): async def user_may_invite(
self,
inviter_userid: str,
invitee_userid: str,
room_id: str,
) -> bool:
return True # allow all invites return True # allow all invites
async def user_may_create_room(self, userid): async def user_may_create_room(self, userid: str) -> bool:
return True # allow all room creations return True # allow all room creations
async def user_may_create_room_alias(self, userid, room_alias): async def user_may_create_room_alias(
self, userid: str, room_alias: RoomAlias
) -> bool:
return True # allow all room aliases return True # allow all room aliases
async def user_may_publish_room(self, userid, room_id): async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
return True # allow publishing of all rooms return True # allow publishing of all rooms
async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool: async def check_media_file_for_spam(
self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
) -> bool:
buf = BytesIO() buf = BytesIO()
await file_wrapper.write_chunks_to(buf.write) await file_wrapper.write_chunks_to(buf.write)
@ -575,7 +599,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
admin.register_servlets, admin.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user = self.register_user("user", "pass") self.user = self.register_user("user", "pass")
self.tok = self.login("user", "pass") self.tok = self.login("user", "pass")
@ -586,7 +610,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
load_legacy_spam_checkers(hs) load_legacy_spam_checkers(hs)
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = default_config("test") config = default_config("test")
config.update( config.update(
@ -602,13 +626,13 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
return config return config
def test_upload_innocent(self): def test_upload_innocent(self) -> None:
"""Attempt to upload some innocent data that should be allowed.""" """Attempt to upload some innocent data that should be allowed."""
self.helper.upload_media( self.helper.upload_media(
self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200 self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
) )
def test_upload_ban(self): def test_upload_ban(self) -> None:
"""Attempt to upload some data that includes bytes "evil", which should """Attempt to upload some data that includes bytes "evil", which should
get rejected by the spam checker. get rejected by the spam checker.
""" """

View file

@ -16,16 +16,21 @@ import base64
import json import json
import os import os
import re import re
from typing import Any, Dict, Optional, Sequence, Tuple, Type
from urllib.parse import urlencode from urllib.parse import urlencode
from twisted.internet._resolver import HostResolution from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.test.proto_helpers import AccumulatingProtocol from twisted.internet.interfaces import IAddress, IResolutionReceiver
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
from synapse.config.oembed import OEmbedEndpointConfig from synapse.config.oembed import OEmbedEndpointConfig
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock
from synapse.util.stringutils import parse_and_validate_mxc_uri from synapse.util.stringutils import parse_and_validate_mxc_uri
from tests import unittest from tests import unittest
@ -52,7 +57,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>" b"</head></html>"
) )
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config() config = self.default_config()
config["url_preview_enabled"] = True config["url_preview_enabled"] = True
@ -113,22 +118,22 @@ class URLPreviewTests(unittest.HomeserverTestCase):
return hs return hs
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.media_repo = hs.get_media_repository_resource() self.media_repo = hs.get_media_repository_resource()
self.preview_url = self.media_repo.children[b"preview_url"] self.preview_url = self.media_repo.children[b"preview_url"]
self.lookups = {} self.lookups: Dict[str, Any] = {}
class Resolver: class Resolver:
def resolveHostName( def resolveHostName(
_self, _self,
resolutionReceiver, resolutionReceiver: IResolutionReceiver,
hostName, hostName: str,
portNumber=0, portNumber: int = 0,
addressTypes=None, addressTypes: Optional[Sequence[Type[IAddress]]] = None,
transportSemantics="TCP", transportSemantics: str = "TCP",
): ) -> IResolutionReceiver:
resolution = HostResolution(hostName) resolution = HostResolution(hostName)
resolutionReceiver.resolutionBegan(resolution) resolutionReceiver.resolutionBegan(resolution)
@ -140,9 +145,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
resolutionReceiver.resolutionComplete() resolutionReceiver.resolutionComplete()
return resolutionReceiver return resolutionReceiver
self.reactor.nameResolver = Resolver() self.reactor.nameResolver = Resolver() # type: ignore[assignment]
def create_test_resource(self): def create_test_resource(self) -> MediaRepositoryResource:
return self.hs.get_media_repository_resource() return self.hs.get_media_repository_resource()
def _assert_small_png(self, json_body: JsonDict) -> None: def _assert_small_png(self, json_body: JsonDict) -> None:
@ -153,7 +158,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(json_body["og:image:type"], "image/png") self.assertEqual(json_body["og:image:type"], "image/png")
self.assertEqual(json_body["matrix:image:size"], 67) self.assertEqual(json_body["matrix:image:size"], 67)
def test_cache_returns_correct_type(self): def test_cache_returns_correct_type(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
channel = self.make_request( channel = self.make_request(
@ -207,7 +212,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
) )
def test_non_ascii_preview_httpequiv(self): def test_non_ascii_preview_httpequiv(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = ( end_content = (
@ -243,7 +248,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
def test_video_rejected(self): def test_video_rejected(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = b"anything" end_content = b"anything"
@ -279,7 +284,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
}, },
) )
def test_audio_rejected(self): def test_audio_rejected(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = b"anything" end_content = b"anything"
@ -315,7 +320,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
}, },
) )
def test_non_ascii_preview_content_type(self): def test_non_ascii_preview_content_type(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = ( end_content = (
@ -350,7 +355,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
def test_overlong_title(self): def test_overlong_title(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = ( end_content = (
@ -387,7 +392,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# We should only see the `og:description` field, as `title` is too long and should be stripped out # We should only see the `og:description` field, as `title` is too long and should be stripped out
self.assertCountEqual(["og:description"], res.keys()) self.assertCountEqual(["og:description"], res.keys())
def test_ipaddr(self): def test_ipaddr(self) -> None:
""" """
IP addresses can be previewed directly. IP addresses can be previewed directly.
""" """
@ -417,7 +422,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
) )
def test_blacklisted_ip_specific(self): def test_blacklisted_ip_specific(self) -> None:
""" """
Blacklisted IP addresses, found via DNS, are not spidered. Blacklisted IP addresses, found via DNS, are not spidered.
""" """
@ -438,7 +443,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
}, },
) )
def test_blacklisted_ip_range(self): def test_blacklisted_ip_range(self) -> None:
""" """
Blacklisted IP ranges, IPs found over DNS, are not spidered. Blacklisted IP ranges, IPs found over DNS, are not spidered.
""" """
@ -457,7 +462,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
}, },
) )
def test_blacklisted_ip_specific_direct(self): def test_blacklisted_ip_specific_direct(self) -> None:
""" """
Blacklisted IP addresses, accessed directly, are not spidered. Blacklisted IP addresses, accessed directly, are not spidered.
""" """
@ -476,7 +481,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.code, 403) self.assertEqual(channel.code, 403)
def test_blacklisted_ip_range_direct(self): def test_blacklisted_ip_range_direct(self) -> None:
""" """
Blacklisted IP ranges, accessed directly, are not spidered. Blacklisted IP ranges, accessed directly, are not spidered.
""" """
@ -493,7 +498,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
}, },
) )
def test_blacklisted_ip_range_whitelisted_ip(self): def test_blacklisted_ip_range_whitelisted_ip(self) -> None:
""" """
Blacklisted but then subsequently whitelisted IP addresses can be Blacklisted but then subsequently whitelisted IP addresses can be
spidered. spidered.
@ -526,7 +531,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
) )
def test_blacklisted_ip_with_external_ip(self): def test_blacklisted_ip_with_external_ip(self) -> None:
""" """
If a hostname resolves a blacklisted IP, even if there's a If a hostname resolves a blacklisted IP, even if there's a
non-blacklisted one, it will be rejected. non-blacklisted one, it will be rejected.
@ -549,7 +554,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
}, },
) )
def test_blacklisted_ipv6_specific(self): def test_blacklisted_ipv6_specific(self) -> None:
""" """
Blacklisted IP addresses, found via DNS, are not spidered. Blacklisted IP addresses, found via DNS, are not spidered.
""" """
@ -572,7 +577,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
}, },
) )
def test_blacklisted_ipv6_range(self): def test_blacklisted_ipv6_range(self) -> None:
""" """
Blacklisted IP ranges, IPs found over DNS, are not spidered. Blacklisted IP ranges, IPs found over DNS, are not spidered.
""" """
@ -591,7 +596,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
}, },
) )
def test_OPTIONS(self): def test_OPTIONS(self) -> None:
""" """
OPTIONS returns the OPTIONS. OPTIONS returns the OPTIONS.
""" """
@ -601,7 +606,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {}) self.assertEqual(channel.json_body, {})
def test_accept_language_config_option(self): def test_accept_language_config_option(self) -> None:
""" """
Accept-Language header is sent to the remote server Accept-Language header is sent to the remote server
""" """
@ -652,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
server.data, server.data,
) )
def test_data_url(self): def test_data_url(self) -> None:
""" """
Requesting to preview a data URL is not supported. Requesting to preview a data URL is not supported.
""" """
@ -675,7 +680,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 500) self.assertEqual(channel.code, 500)
def test_inline_data_url(self): def test_inline_data_url(self) -> None:
""" """
An inline image (as a data URL) should be parsed properly. An inline image (as a data URL) should be parsed properly.
""" """
@ -712,7 +717,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self._assert_small_png(channel.json_body) self._assert_small_png(channel.json_body)
def test_oembed_photo(self): def test_oembed_photo(self) -> None:
"""Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL.""" """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
@ -771,7 +776,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345") self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345")
self._assert_small_png(body) self._assert_small_png(body)
def test_oembed_rich(self): def test_oembed_rich(self) -> None:
"""Test an oEmbed endpoint which returns HTML content via the 'rich' type.""" """Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
@ -817,7 +822,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
}, },
) )
def test_oembed_format(self): def test_oembed_format(self) -> None:
"""Test an oEmbed endpoint which requires the format in the URL.""" """Test an oEmbed endpoint which requires the format in the URL."""
self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")] self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")]
@ -866,7 +871,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
}, },
) )
def test_oembed_autodiscovery(self): def test_oembed_autodiscovery(self) -> None:
""" """
Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL. Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL.
1. Request a preview of a URL which is not known to the oEmbed code. 1. Request a preview of a URL which is not known to the oEmbed code.
@ -962,7 +967,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
) )
self._assert_small_png(body) self._assert_small_png(body)
def _download_image(self): def _download_image(self) -> Tuple[str, str]:
"""Downloads an image into the URL cache. """Downloads an image into the URL cache.
Returns: Returns:
A (host, media_id) tuple representing the MXC URI of the image. A (host, media_id) tuple representing the MXC URI of the image.
@ -995,7 +1000,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertIsNone(_port) self.assertIsNone(_port)
return host, media_id return host, media_id
def test_storage_providers_exclude_files(self): def test_storage_providers_exclude_files(self) -> None:
"""Test that files are not stored in or fetched from storage providers.""" """Test that files are not stored in or fetched from storage providers."""
host, media_id = self._download_image() host, media_id = self._download_image()
@ -1037,7 +1042,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"URL cache file was unexpectedly retrieved from a storage provider", "URL cache file was unexpectedly retrieved from a storage provider",
) )
def test_storage_providers_exclude_thumbnails(self): def test_storage_providers_exclude_thumbnails(self) -> None:
"""Test that thumbnails are not stored in or fetched from storage providers.""" """Test that thumbnails are not stored in or fetched from storage providers."""
host, media_id = self._download_image() host, media_id = self._download_image()
@ -1090,7 +1095,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"URL cache thumbnail was unexpectedly retrieved from a storage provider", "URL cache thumbnail was unexpectedly retrieved from a storage provider",
) )
def test_cache_expiry(self): def test_cache_expiry(self) -> None:
"""Test that URL cache files and thumbnails are cleaned up properly on expiry.""" """Test that URL cache files and thumbnails are cleaned up properly on expiry."""
self.preview_url.clock = MockClock() self.preview_url.clock = MockClock()