Add type hints to tests/rest (#12146)

* Add type hints to `tests/rest`

* newsfile

* change import from `SigningKey`
This commit is contained in:
Dirk Klimpel 2022-03-03 17:05:44 +01:00 committed by GitHub
parent 1d11b452b7
commit 7e91107be1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 104 additions and 92 deletions

1
changelog.d/12146.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `tests/rest`.

View file

@ -89,8 +89,6 @@ exclude = (?x)
|tests/push/test_presentable_names.py |tests/push/test_presentable_names.py
|tests/push/test_push_rule_evaluator.py |tests/push/test_push_rule_evaluator.py
|tests/rest/client/test_transactions.py |tests/rest/client/test_transactions.py
|tests/rest/key/v2/test_remote_key_resource.py
|tests/rest/media/v1/test_base.py
|tests/rest/media/v1/test_media_storage.py |tests/rest/media/v1/test_media_storage.py
|tests/rest/media/v1/test_url_preview.py |tests/rest/media/v1/test_url_preview.py
|tests/scripts/test_new_matrix_user.py |tests/scripts/test_new_matrix_user.py
@ -254,10 +252,7 @@ disallow_untyped_defs = True
[mypy-tests.storage.test_user_directory] [mypy-tests.storage.test_user_directory]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.rest.admin.*] [mypy-tests.rest.*]
disallow_untyped_defs = True
[mypy-tests.rest.client.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client] [mypy-tests.federation.transport.test_client]

View file

@ -13,19 +13,24 @@
# limitations under the License. # limitations under the License.
import urllib.parse import urllib.parse
from io import BytesIO, StringIO from io import BytesIO, StringIO
from typing import Any, Dict, Optional, Union
from unittest.mock import Mock from unittest.mock import Mock
import signedjson.key import signedjson.key
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from nacl.signing import SigningKey
from signedjson.sign import sign_json from signedjson.sign import sign_json
from signedjson.types import SigningKey
from twisted.web.resource import NoResource from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import NoResource, Resource
from synapse.crypto.keyring import PerspectivesKeyFetcher from synapse.crypto.keyring import PerspectivesKeyFetcher
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
@ -35,11 +40,11 @@ from tests.utils import default_config
class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock() self.http_client = Mock()
return self.setup_test_homeserver(federation_http_client=self.http_client) return self.setup_test_homeserver(federation_http_client=self.http_client)
def create_test_resource(self): def create_test_resource(self) -> Resource:
return create_resource_tree( return create_resource_tree(
{"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
) )
@ -51,7 +56,12 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
Tell the mock http client to expect an outgoing GET request for the given key Tell the mock http client to expect an outgoing GET request for the given key
""" """
async def get_json(destination, path, ignore_backoff=False, **kwargs): async def get_json(
destination: str,
path: str,
ignore_backoff: bool = False,
**kwargs: Any,
) -> Union[JsonDict, list]:
self.assertTrue(ignore_backoff) self.assertTrue(ignore_backoff)
self.assertEqual(destination, server_name) self.assertEqual(destination, server_name)
key_id = "%s:%s" % (signing_key.alg, signing_key.version) key_id = "%s:%s" % (signing_key.alg, signing_key.version)
@ -84,7 +94,8 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
Checks that the response is a 200 and returns the decoded json body. Checks that the response is a 200 and returns the decoded json body.
""" """
channel = FakeChannel(self.site, self.reactor) channel = FakeChannel(self.site, self.reactor)
req = SynapseRequest(channel, self.site) # channel is a `FakeChannel` but `HTTPChannel` is expected
req = SynapseRequest(channel, self.site) # type: ignore[arg-type]
req.content = BytesIO(b"") req.content = BytesIO(b"")
req.requestReceived( req.requestReceived(
b"GET", b"GET",
@ -97,7 +108,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
resp = channel.json_body resp = channel.json_body
return resp return resp
def test_get_key(self): def test_get_key(self) -> None:
"""Fetch a remote key""" """Fetch a remote key"""
SERVER_NAME = "remote.server" SERVER_NAME = "remote.server"
testkey = signedjson.key.generate_signing_key("ver1") testkey = signedjson.key.generate_signing_key("ver1")
@ -114,7 +125,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
self.assertIn(SERVER_NAME, keys[0]["signatures"]) self.assertIn(SERVER_NAME, keys[0]["signatures"])
self.assertIn(self.hs.hostname, keys[0]["signatures"]) self.assertIn(self.hs.hostname, keys[0]["signatures"])
def test_get_own_key(self): def test_get_own_key(self) -> None:
"""Fetch our own key""" """Fetch our own key"""
testkey = signedjson.key.generate_signing_key("ver1") testkey = signedjson.key.generate_signing_key("ver1")
self.expect_outgoing_key_request(self.hs.hostname, testkey) self.expect_outgoing_key_request(self.hs.hostname, testkey)
@ -141,7 +152,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
endpoint, to check that the two implementations are compatible. endpoint, to check that the two implementations are compatible.
""" """
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
# replace the signing key with our own # replace the signing key with our own
@ -152,7 +163,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
return config return config
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# make a second homeserver, configured to use the first one as a key notary # make a second homeserver, configured to use the first one as a key notary
self.http_client2 = Mock() self.http_client2 = Mock()
config = default_config(name="keyclient") config = default_config(name="keyclient")
@ -175,7 +186,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
# wire up outbound POST /key/v2/query requests from hs2 so that they # wire up outbound POST /key/v2/query requests from hs2 so that they
# will be forwarded to hs1 # will be forwarded to hs1
async def post_json(destination, path, data): async def post_json(
destination: str, path: str, data: Optional[JsonDict] = None
) -> Union[JsonDict, list]:
self.assertEqual(destination, self.hs.hostname) self.assertEqual(destination, self.hs.hostname)
self.assertEqual( self.assertEqual(
path, path,
@ -183,7 +196,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
) )
channel = FakeChannel(self.site, self.reactor) channel = FakeChannel(self.site, self.reactor)
req = SynapseRequest(channel, self.site) # channel is a `FakeChannel` but `HTTPChannel` is expected
req = SynapseRequest(channel, self.site) # type: ignore[arg-type]
req.content = BytesIO(encode_canonical_json(data)) req.content = BytesIO(encode_canonical_json(data))
req.requestReceived( req.requestReceived(
@ -198,7 +212,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
self.http_client2.post_json.side_effect = post_json self.http_client2.post_json.side_effect = post_json
def test_get_key(self): def test_get_key(self) -> None:
"""Fetch a key belonging to a random server""" """Fetch a key belonging to a random server"""
# make up a key to be fetched. # make up a key to be fetched.
testkey = signedjson.key.generate_signing_key("abc") testkey = signedjson.key.generate_signing_key("abc")
@ -218,7 +232,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
signedjson.key.encode_verify_key_base64(testkey.verify_key), signedjson.key.encode_verify_key_base64(testkey.verify_key),
) )
def test_get_notary_key(self): def test_get_notary_key(self) -> None:
"""Fetch a key belonging to the notary server""" """Fetch a key belonging to the notary server"""
# make up a key to be fetched. We randomise the keyid to try to get it to # make up a key to be fetched. We randomise the keyid to try to get it to
# appear before the key server signing key sometimes (otherwise we bail out # appear before the key server signing key sometimes (otherwise we bail out
@ -240,7 +254,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
signedjson.key.encode_verify_key_base64(testkey.verify_key), signedjson.key.encode_verify_key_base64(testkey.verify_key),
) )
def test_get_notary_keyserver_key(self): def test_get_notary_keyserver_key(self) -> None:
"""Fetch the notary's keyserver key""" """Fetch the notary's keyserver key"""
# we expect hs1 to make a regular key request to itself # we expect hs1 to make a regular key request to itself
self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key) self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key)

View file

@ -28,11 +28,11 @@ class GetFileNameFromHeadersTests(unittest.TestCase):
b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar", b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar",
} }
def tests(self): def tests(self) -> None:
for hdr, expected in self.TEST_CASES.items(): for hdr, expected in self.TEST_CASES.items():
res = get_filename_from_headers({b"Content-Disposition": [hdr]}) res = get_filename_from_headers({b"Content-Disposition": [hdr]})
self.assertEqual( self.assertEqual(
res, res,
expected, expected,
"expected output for %s to be %s but was %s" % (hdr, expected, res), f"expected output for {hdr!r} to be {expected} but was {res}",
) )

View file

@ -21,12 +21,12 @@ from tests import unittest
class MediaFilePathsTestCase(unittest.TestCase): class MediaFilePathsTestCase(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
super().setUp() super().setUp()
self.filepaths = MediaFilePaths("/media_store") self.filepaths = MediaFilePaths("/media_store")
def test_local_media_filepath(self): def test_local_media_filepath(self) -> None:
"""Test local media paths""" """Test local media paths"""
self.assertEqual( self.assertEqual(
self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"), self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
@ -37,7 +37,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg", "/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg",
) )
def test_local_media_thumbnail(self): def test_local_media_thumbnail(self) -> None:
"""Test local media thumbnail paths""" """Test local media thumbnail paths"""
self.assertEqual( self.assertEqual(
self.filepaths.local_media_thumbnail_rel( self.filepaths.local_media_thumbnail_rel(
@ -52,14 +52,14 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
) )
def test_local_media_thumbnail_dir(self): def test_local_media_thumbnail_dir(self) -> None:
"""Test local media thumbnail directory paths""" """Test local media thumbnail directory paths"""
self.assertEqual( self.assertEqual(
self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"), self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"),
"/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
) )
def test_remote_media_filepath(self): def test_remote_media_filepath(self) -> None:
"""Test remote media paths""" """Test remote media paths"""
self.assertEqual( self.assertEqual(
self.filepaths.remote_media_filepath_rel( self.filepaths.remote_media_filepath_rel(
@ -74,7 +74,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", "/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
) )
def test_remote_media_thumbnail(self): def test_remote_media_thumbnail(self) -> None:
"""Test remote media thumbnail paths""" """Test remote media thumbnail paths"""
self.assertEqual( self.assertEqual(
self.filepaths.remote_media_thumbnail_rel( self.filepaths.remote_media_thumbnail_rel(
@ -99,7 +99,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
) )
def test_remote_media_thumbnail_legacy(self): def test_remote_media_thumbnail_legacy(self) -> None:
"""Test old-style remote media thumbnail paths""" """Test old-style remote media thumbnail paths"""
self.assertEqual( self.assertEqual(
self.filepaths.remote_media_thumbnail_rel_legacy( self.filepaths.remote_media_thumbnail_rel_legacy(
@ -108,7 +108,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg", "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg",
) )
def test_remote_media_thumbnail_dir(self): def test_remote_media_thumbnail_dir(self) -> None:
"""Test remote media thumbnail directory paths""" """Test remote media thumbnail directory paths"""
self.assertEqual( self.assertEqual(
self.filepaths.remote_media_thumbnail_dir( self.filepaths.remote_media_thumbnail_dir(
@ -117,7 +117,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
) )
def test_url_cache_filepath(self): def test_url_cache_filepath(self) -> None:
"""Test URL cache paths""" """Test URL cache paths"""
self.assertEqual( self.assertEqual(
self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"), self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"),
@ -128,7 +128,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar", "/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar",
) )
def test_url_cache_filepath_legacy(self): def test_url_cache_filepath_legacy(self) -> None:
"""Test old-style URL cache paths""" """Test old-style URL cache paths"""
self.assertEqual( self.assertEqual(
self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"), self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
@ -139,7 +139,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg", "/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg",
) )
def test_url_cache_filepath_dirs_to_delete(self): def test_url_cache_filepath_dirs_to_delete(self) -> None:
"""Test URL cache cleanup paths""" """Test URL cache cleanup paths"""
self.assertEqual( self.assertEqual(
self.filepaths.url_cache_filepath_dirs_to_delete( self.filepaths.url_cache_filepath_dirs_to_delete(
@ -148,7 +148,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
["/media_store/url_cache/2020-01-02"], ["/media_store/url_cache/2020-01-02"],
) )
def test_url_cache_filepath_dirs_to_delete_legacy(self): def test_url_cache_filepath_dirs_to_delete_legacy(self) -> None:
"""Test old-style URL cache cleanup paths""" """Test old-style URL cache cleanup paths"""
self.assertEqual( self.assertEqual(
self.filepaths.url_cache_filepath_dirs_to_delete( self.filepaths.url_cache_filepath_dirs_to_delete(
@ -160,7 +160,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
], ],
) )
def test_url_cache_thumbnail(self): def test_url_cache_thumbnail(self) -> None:
"""Test URL cache thumbnail paths""" """Test URL cache thumbnail paths"""
self.assertEqual( self.assertEqual(
self.filepaths.url_cache_thumbnail_rel( self.filepaths.url_cache_thumbnail_rel(
@ -175,7 +175,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale", "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale",
) )
def test_url_cache_thumbnail_legacy(self): def test_url_cache_thumbnail_legacy(self) -> None:
"""Test old-style URL cache thumbnail paths""" """Test old-style URL cache thumbnail paths"""
self.assertEqual( self.assertEqual(
self.filepaths.url_cache_thumbnail_rel( self.filepaths.url_cache_thumbnail_rel(
@ -190,7 +190,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
) )
def test_url_cache_thumbnail_directory(self): def test_url_cache_thumbnail_directory(self) -> None:
"""Test URL cache thumbnail directory paths""" """Test URL cache thumbnail directory paths"""
self.assertEqual( self.assertEqual(
self.filepaths.url_cache_thumbnail_directory_rel( self.filepaths.url_cache_thumbnail_directory_rel(
@ -203,7 +203,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar", "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
) )
def test_url_cache_thumbnail_directory_legacy(self): def test_url_cache_thumbnail_directory_legacy(self) -> None:
"""Test old-style URL cache thumbnail directory paths""" """Test old-style URL cache thumbnail directory paths"""
self.assertEqual( self.assertEqual(
self.filepaths.url_cache_thumbnail_directory_rel( self.filepaths.url_cache_thumbnail_directory_rel(
@ -216,7 +216,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
) )
def test_url_cache_thumbnail_dirs_to_delete(self): def test_url_cache_thumbnail_dirs_to_delete(self) -> None:
"""Test URL cache thumbnail cleanup paths""" """Test URL cache thumbnail cleanup paths"""
self.assertEqual( self.assertEqual(
self.filepaths.url_cache_thumbnail_dirs_to_delete( self.filepaths.url_cache_thumbnail_dirs_to_delete(
@ -228,7 +228,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
], ],
) )
def test_url_cache_thumbnail_dirs_to_delete_legacy(self): def test_url_cache_thumbnail_dirs_to_delete_legacy(self) -> None:
"""Test old-style URL cache thumbnail cleanup paths""" """Test old-style URL cache thumbnail cleanup paths"""
self.assertEqual( self.assertEqual(
self.filepaths.url_cache_thumbnail_dirs_to_delete( self.filepaths.url_cache_thumbnail_dirs_to_delete(
@ -241,7 +241,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
], ],
) )
def test_server_name_validation(self): def test_server_name_validation(self) -> None:
"""Test validation of server names""" """Test validation of server names"""
self._test_path_validation( self._test_path_validation(
[ [
@ -274,7 +274,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
], ],
) )
def test_file_id_validation(self): def test_file_id_validation(self) -> None:
"""Test validation of local, remote and legacy URL cache file / media IDs""" """Test validation of local, remote and legacy URL cache file / media IDs"""
# File / media IDs get split into three parts to form paths, consisting of the # File / media IDs get split into three parts to form paths, consisting of the
# first two characters, next two characters and rest of the ID. # first two characters, next two characters and rest of the ID.
@ -357,7 +357,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
invalid_values=invalid_file_ids, invalid_values=invalid_file_ids,
) )
def test_url_cache_media_id_validation(self): def test_url_cache_media_id_validation(self) -> None:
"""Test validation of URL cache media IDs""" """Test validation of URL cache media IDs"""
self._test_path_validation( self._test_path_validation(
[ [
@ -387,7 +387,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
], ],
) )
def test_content_type_validation(self): def test_content_type_validation(self) -> None:
"""Test validation of thumbnail content types""" """Test validation of thumbnail content types"""
self._test_path_validation( self._test_path_validation(
[ [
@ -410,7 +410,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
], ],
) )
def test_thumbnail_method_validation(self): def test_thumbnail_method_validation(self) -> None:
"""Test validation of thumbnail methods""" """Test validation of thumbnail methods"""
self._test_path_validation( self._test_path_validation(
[ [
@ -440,7 +440,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
parameter: str, parameter: str,
valid_values: Iterable[str], valid_values: Iterable[str],
invalid_values: Iterable[str], invalid_values: Iterable[str],
): ) -> None:
"""Test that the specified methods validate the named parameter as expected """Test that the specified methods validate the named parameter as expected
Args: Args:

View file

@ -32,7 +32,7 @@ class SummarizeTestCase(unittest.TestCase):
if not lxml: if not lxml:
skip = "url preview feature requires lxml" skip = "url preview feature requires lxml"
def test_long_summarize(self): def test_long_summarize(self) -> None:
example_paras = [ example_paras = [
"""Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in
@ -90,7 +90,7 @@ class SummarizeTestCase(unittest.TestCase):
" Tromsøya had a population of 36,088. Substantial parts of the urban…", " Tromsøya had a population of 36,088. Substantial parts of the urban…",
) )
def test_short_summarize(self): def test_short_summarize(self) -> None:
example_paras = [ example_paras = [
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@ -117,7 +117,7 @@ class SummarizeTestCase(unittest.TestCase):
" most of the year.", " most of the year.",
) )
def test_small_then_large_summarize(self): def test_small_then_large_summarize(self) -> None:
example_paras = [ example_paras = [
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@ -150,7 +150,7 @@ class CalcOgTestCase(unittest.TestCase):
if not lxml: if not lxml:
skip = "url preview feature requires lxml" skip = "url preview feature requires lxml"
def test_simple(self): def test_simple(self) -> None:
html = b""" html = b"""
<html> <html>
<head><title>Foo</title></head> <head><title>Foo</title></head>
@ -165,7 +165,7 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_comment(self): def test_comment(self) -> None:
html = b""" html = b"""
<html> <html>
<head><title>Foo</title></head> <head><title>Foo</title></head>
@ -181,7 +181,7 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_comment2(self): def test_comment2(self) -> None:
html = b""" html = b"""
<html> <html>
<head><title>Foo</title></head> <head><title>Foo</title></head>
@ -206,7 +206,7 @@ class CalcOgTestCase(unittest.TestCase):
}, },
) )
def test_script(self): def test_script(self) -> None:
html = b""" html = b"""
<html> <html>
<head><title>Foo</title></head> <head><title>Foo</title></head>
@ -222,7 +222,7 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_missing_title(self): def test_missing_title(self) -> None:
html = b""" html = b"""
<html> <html>
<body> <body>
@ -236,7 +236,7 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
def test_h1_as_title(self): def test_h1_as_title(self) -> None:
html = b""" html = b"""
<html> <html>
<meta property="og:description" content="Some text."/> <meta property="og:description" content="Some text."/>
@ -251,7 +251,7 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
def test_missing_title_and_broken_h1(self): def test_missing_title_and_broken_h1(self) -> None:
html = b""" html = b"""
<html> <html>
<body> <body>
@ -266,19 +266,19 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
def test_empty(self): def test_empty(self) -> None:
"""Test a body with no data in it.""" """Test a body with no data in it."""
html = b"" html = b""
tree = decode_body(html, "http://example.com/test.html") tree = decode_body(html, "http://example.com/test.html")
self.assertIsNone(tree) self.assertIsNone(tree)
def test_no_tree(self): def test_no_tree(self) -> None:
"""A valid body with no tree in it.""" """A valid body with no tree in it."""
html = b"\x00" html = b"\x00"
tree = decode_body(html, "http://example.com/test.html") tree = decode_body(html, "http://example.com/test.html")
self.assertIsNone(tree) self.assertIsNone(tree)
def test_xml(self): def test_xml(self) -> None:
"""Test decoding XML and ensure it works properly.""" """Test decoding XML and ensure it works properly."""
# Note that the strip() call is important to ensure the xml tag starts # Note that the strip() call is important to ensure the xml tag starts
# at the initial byte. # at the initial byte.
@ -293,7 +293,7 @@ class CalcOgTestCase(unittest.TestCase):
og = parse_html_to_open_graph(tree, "http://example.com/test.html") og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_invalid_encoding(self): def test_invalid_encoding(self) -> None:
"""An invalid character encoding should be ignored and treated as UTF-8, if possible.""" """An invalid character encoding should be ignored and treated as UTF-8, if possible."""
html = b""" html = b"""
<html> <html>
@ -307,7 +307,7 @@ class CalcOgTestCase(unittest.TestCase):
og = parse_html_to_open_graph(tree, "http://example.com/test.html") og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_invalid_encoding2(self): def test_invalid_encoding2(self) -> None:
"""A body which doesn't match the sent character encoding.""" """A body which doesn't match the sent character encoding."""
# Note that this contains an invalid UTF-8 sequence in the title. # Note that this contains an invalid UTF-8 sequence in the title.
html = b""" html = b"""
@ -322,7 +322,7 @@ class CalcOgTestCase(unittest.TestCase):
og = parse_html_to_open_graph(tree, "http://example.com/test.html") og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
def test_windows_1252(self): def test_windows_1252(self) -> None:
"""A body which uses cp1252, but doesn't declare that.""" """A body which uses cp1252, but doesn't declare that."""
html = b""" html = b"""
<html> <html>
@ -338,7 +338,7 @@ class CalcOgTestCase(unittest.TestCase):
class MediaEncodingTestCase(unittest.TestCase): class MediaEncodingTestCase(unittest.TestCase):
def test_meta_charset(self): def test_meta_charset(self) -> None:
"""A character encoding is found via the meta tag.""" """A character encoding is found via the meta tag."""
encodings = _get_html_media_encodings( encodings = _get_html_media_encodings(
b""" b"""
@ -363,7 +363,7 @@ class MediaEncodingTestCase(unittest.TestCase):
) )
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
def test_meta_charset_underscores(self): def test_meta_charset_underscores(self) -> None:
"""A character encoding contains underscore.""" """A character encoding contains underscore."""
encodings = _get_html_media_encodings( encodings = _get_html_media_encodings(
b""" b"""
@ -376,7 +376,7 @@ class MediaEncodingTestCase(unittest.TestCase):
) )
self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"]) self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"])
def test_xml_encoding(self): def test_xml_encoding(self) -> None:
"""A character encoding is found via the meta tag.""" """A character encoding is found via the meta tag."""
encodings = _get_html_media_encodings( encodings = _get_html_media_encodings(
b""" b"""
@ -388,7 +388,7 @@ class MediaEncodingTestCase(unittest.TestCase):
) )
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
def test_meta_xml_encoding(self): def test_meta_xml_encoding(self) -> None:
"""Meta tags take precedence over XML encoding.""" """Meta tags take precedence over XML encoding."""
encodings = _get_html_media_encodings( encodings = _get_html_media_encodings(
b""" b"""
@ -402,7 +402,7 @@ class MediaEncodingTestCase(unittest.TestCase):
) )
self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"]) self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"])
def test_content_type(self): def test_content_type(self) -> None:
"""A character encoding is found via the Content-Type header.""" """A character encoding is found via the Content-Type header."""
# Test a few variations of the header. # Test a few variations of the header.
headers = ( headers = (
@ -417,12 +417,12 @@ class MediaEncodingTestCase(unittest.TestCase):
encodings = _get_html_media_encodings(b"", header) encodings = _get_html_media_encodings(b"", header)
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
def test_fallback(self): def test_fallback(self) -> None:
"""A character encoding cannot be found in the body or header.""" """A character encoding cannot be found in the body or header."""
encodings = _get_html_media_encodings(b"", "text/html") encodings = _get_html_media_encodings(b"", "text/html")
self.assertEqual(list(encodings), ["utf-8", "cp1252"]) self.assertEqual(list(encodings), ["utf-8", "cp1252"])
def test_duplicates(self): def test_duplicates(self) -> None:
"""Ensure each encoding is only attempted once.""" """Ensure each encoding is only attempted once."""
encodings = _get_html_media_encodings( encodings = _get_html_media_encodings(
b""" b"""
@ -436,7 +436,7 @@ class MediaEncodingTestCase(unittest.TestCase):
) )
self.assertEqual(list(encodings), ["utf-8", "cp1252"]) self.assertEqual(list(encodings), ["utf-8", "cp1252"])
def test_unknown_invalid(self): def test_unknown_invalid(self) -> None:
"""A character encoding should be ignored if it is unknown or invalid.""" """A character encoding should be ignored if it is unknown or invalid."""
encodings = _get_html_media_encodings( encodings = _get_html_media_encodings(
b""" b"""
@ -451,7 +451,7 @@ class MediaEncodingTestCase(unittest.TestCase):
class RebaseUrlTestCase(unittest.TestCase): class RebaseUrlTestCase(unittest.TestCase):
def test_relative(self): def test_relative(self) -> None:
"""Relative URLs should be resolved based on the context of the base URL.""" """Relative URLs should be resolved based on the context of the base URL."""
self.assertEqual( self.assertEqual(
rebase_url("subpage", "https://example.com/foo/"), rebase_url("subpage", "https://example.com/foo/"),
@ -466,14 +466,14 @@ class RebaseUrlTestCase(unittest.TestCase):
"https://example.com/bar", "https://example.com/bar",
) )
def test_absolute(self): def test_absolute(self) -> None:
"""Absolute URLs should not be modified.""" """Absolute URLs should not be modified."""
self.assertEqual( self.assertEqual(
rebase_url("https://alice.com/a/", "https://example.com/foo/"), rebase_url("https://alice.com/a/", "https://example.com/foo/"),
"https://alice.com/a/", "https://alice.com/a/",
) )
def test_data(self): def test_data(self) -> None:
"""Data URLs should not be modified.""" """Data URLs should not be modified."""
self.assertEqual( self.assertEqual(
rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"), rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"),

View file

@ -16,7 +16,7 @@ import json
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.rest.media.v1.oembed import OEmbedProvider from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
@ -25,15 +25,15 @@ from tests.unittest import HomeserverTestCase
class OEmbedTests(HomeserverTestCase): class OEmbedTests(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.oembed = OEmbedProvider(homeserver) self.oembed = OEmbedProvider(hs)
def parse_response(self, response: JsonDict): def parse_response(self, response: JsonDict) -> OEmbedResult:
return self.oembed.parse_oembed_response( return self.oembed.parse_oembed_response(
"https://test", json.dumps(response).encode("utf-8") "https://test", json.dumps(response).encode("utf-8")
) )
def test_version(self): def test_version(self) -> None:
"""Accept versions that are similar to 1.0 as a string or int (or missing).""" """Accept versions that are similar to 1.0 as a string or int (or missing)."""
for version in ("1.0", 1.0, 1): for version in ("1.0", 1.0, 1):
result = self.parse_response({"version": version, "type": "link"}) result = self.parse_response({"version": version, "type": "link"})

View file

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus
from synapse.rest.health import HealthResource from synapse.rest.health import HealthResource
@ -19,12 +19,12 @@ from tests import unittest
class HealthCheckTests(unittest.HomeserverTestCase): class HealthCheckTests(unittest.HomeserverTestCase):
def create_test_resource(self): def create_test_resource(self) -> HealthResource:
# replace the JsonResource with a HealthResource. # replace the JsonResource with a HealthResource.
return HealthResource() return HealthResource()
def test_health(self): def test_health(self) -> None:
channel = self.make_request("GET", "/health", shorthand=False) channel = self.make_request("GET", "/health", shorthand=False)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.result["body"], b"OK") self.assertEqual(channel.result["body"], b"OK")

View file

@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.rest.well_known import well_known_resource from synapse.rest.well_known import well_known_resource
@ -19,7 +21,7 @@ from tests import unittest
class WellKnownTests(unittest.HomeserverTestCase): class WellKnownTests(unittest.HomeserverTestCase):
def create_test_resource(self): def create_test_resource(self) -> Resource:
# replace the JsonResource with a Resource wrapping the WellKnownResource # replace the JsonResource with a Resource wrapping the WellKnownResource
res = Resource() res = Resource()
res.putChild(b".well-known", well_known_resource(self.hs)) res.putChild(b".well-known", well_known_resource(self.hs))
@ -31,12 +33,12 @@ class WellKnownTests(unittest.HomeserverTestCase):
"default_identity_server": "https://testis", "default_identity_server": "https://testis",
} }
) )
def test_client_well_known(self): def test_client_well_known(self) -> None:
channel = self.make_request( channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False "GET", "/.well-known/matrix/client", shorthand=False
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual( self.assertEqual(
channel.json_body, channel.json_body,
{ {
@ -50,27 +52,27 @@ class WellKnownTests(unittest.HomeserverTestCase):
"public_baseurl": None, "public_baseurl": None,
} }
) )
def test_client_well_known_no_public_baseurl(self): def test_client_well_known_no_public_baseurl(self) -> None:
channel = self.make_request( channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False "GET", "/.well-known/matrix/client", shorthand=False
) )
self.assertEqual(channel.code, 404) self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
@unittest.override_config({"serve_server_wellknown": True}) @unittest.override_config({"serve_server_wellknown": True})
def test_server_well_known(self): def test_server_well_known(self) -> None:
channel = self.make_request( channel = self.make_request(
"GET", "/.well-known/matrix/server", shorthand=False "GET", "/.well-known/matrix/server", shorthand=False
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual( self.assertEqual(
channel.json_body, channel.json_body,
{"m.server": "test:443"}, {"m.server": "test:443"},
) )
def test_server_well_known_disabled(self): def test_server_well_known_disabled(self) -> None:
channel = self.make_request( channel = self.make_request(
"GET", "/.well-known/matrix/server", shorthand=False "GET", "/.well-known/matrix/server", shorthand=False
) )
self.assertEqual(channel.code, 404) self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)