0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-09-26 19:49:01 +02:00

Add type hints to tests/rest/admin (#11851)

This commit is contained in:
Dirk Klimpel 2022-01-31 20:20:05 +01:00 committed by GitHub
parent 0da2301b21
commit 901b264c0c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 184 additions and 232 deletions

1
changelog.d/11851.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `tests/rest/admin`.

View file

@ -77,9 +77,6 @@ exclude = (?x)
|tests/push/test_http.py |tests/push/test_http.py
|tests/push/test_presentable_names.py |tests/push/test_presentable_names.py
|tests/push/test_push_rule_evaluator.py |tests/push/test_push_rule_evaluator.py
|tests/rest/admin/test_admin.py
|tests/rest/admin/test_user.py
|tests/rest/admin/test_username_available.py
|tests/rest/client/test_account.py |tests/rest/client/test_account.py
|tests/rest/client/test_events.py |tests/rest/client/test_events.py
|tests/rest/client/test_filter.py |tests/rest/client/test_filter.py

View file

@ -12,18 +12,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import urllib.parse import urllib.parse
from http import HTTPStatus from http import HTTPStatus
from unittest.mock import Mock from typing import List
from twisted.internet.defer import Deferred from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.logging.context import make_deferred_yieldable
from synapse.rest.admin import VersionServlet from synapse.rest.admin import VersionServlet
from synapse.rest.client import groups, login, room from synapse.rest.client import groups, login, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request from tests.server import FakeSite, make_request
@ -33,12 +35,12 @@ from tests.test_utils import SMALL_PNG
class VersionTestCase(unittest.HomeserverTestCase): class VersionTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/server_version" url = "/_synapse/admin/v1/server_version"
def create_test_resource(self): def create_test_resource(self) -> JsonResource:
resource = JsonResource(self.hs) resource = JsonResource(self.hs)
VersionServlet(self.hs).register(resource) VersionServlet(self.hs).register(resource)
return resource return resource
def test_version_string(self): def test_version_string(self) -> None:
channel = self.make_request("GET", self.url, shorthand=False) channel = self.make_request("GET", self.url, shorthand=False)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
@ -54,14 +56,14 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
groups.register_servlets, groups.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass") self.other_user = self.register_user("user", "pass")
self.other_user_token = self.login("user", "pass") self.other_user_token = self.login("user", "pass")
def test_delete_group(self): def test_delete_group(self) -> None:
# Create a new group # Create a new group
channel = self.make_request( channel = self.make_request(
"POST", "POST",
@ -112,7 +114,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
self.assertNotIn(group_id, self._get_groups_user_is_in(self.other_user_token)) self.assertNotIn(group_id, self._get_groups_user_is_in(self.other_user_token))
def _check_group(self, group_id, expect_code): def _check_group(self, group_id: str, expect_code: int) -> None:
"""Assert that trying to fetch the given group results in the given """Assert that trying to fetch the given group results in the given
HTTP status code HTTP status code
""" """
@ -124,7 +126,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
self.assertEqual(expect_code, channel.code, msg=channel.json_body) self.assertEqual(expect_code, channel.code, msg=channel.json_body)
def _get_groups_user_is_in(self, access_token): def _get_groups_user_is_in(self, access_token: str) -> List[str]:
"""Returns the list of groups the user is in (given their access token)""" """Returns the list of groups the user is in (given their access token)"""
channel = self.make_request("GET", b"/joined_groups", access_token=access_token) channel = self.make_request("GET", b"/joined_groups", access_token=access_token)
@ -143,59 +145,15 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Allow for uploading and downloading to/from the media repo # Allow for uploading and downloading to/from the media repo
self.media_repo = hs.get_media_repository_resource() self.media_repo = hs.get_media_repository_resource()
self.download_resource = self.media_repo.children[b"download"] self.download_resource = self.media_repo.children[b"download"]
self.upload_resource = self.media_repo.children[b"upload"] self.upload_resource = self.media_repo.children[b"upload"]
def make_homeserver(self, reactor, clock): def _ensure_quarantined(
self, admin_user_tok: str, server_and_media_id: str
self.fetches = [] ) -> None:
async def get_file(destination, path, output_stream, args=None, max_size=None):
"""
Returns tuple[int,dict,str,int] of file length, response headers,
absolute URI, and response code.
"""
def write_to(r):
data, response = r
output_stream.write(data)
return response
d = Deferred()
d.addCallback(write_to)
self.fetches.append((d, destination, path, args))
return await make_deferred_yieldable(d)
client = Mock()
client.get_file = get_file
self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
os.mkdir(self.storage_path)
os.mkdir(self.media_store_path)
config = self.default_config()
config["media_store_path"] = self.media_store_path
config["thumbnail_requirements"] = {}
config["max_image_pixels"] = 2000000
provider_config = {
"module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
"store_local": True,
"store_synchronous": False,
"store_remote": True,
"config": {"directory": self.storage_path},
}
config["media_storage_providers"] = [provider_config]
hs = self.setup_test_homeserver(config=config, federation_http_client=client)
return hs
def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
"""Ensure a piece of media is quarantined when trying to access it.""" """Ensure a piece of media is quarantined when trying to access it."""
channel = make_request( channel = make_request(
self.reactor, self.reactor,
@ -216,12 +174,18 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
), ),
) )
def test_quarantine_media_requires_admin(self): @parameterized.expand(
[
# Attempt quarantine media APIs as non-admin
"/_synapse/admin/v1/media/quarantine/example.org/abcde12345",
# And the roomID/userID endpoint
"/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine",
]
)
def test_quarantine_media_requires_admin(self, url: str) -> None:
self.register_user("nonadmin", "pass", admin=False) self.register_user("nonadmin", "pass", admin=False)
non_admin_user_tok = self.login("nonadmin", "pass") non_admin_user_tok = self.login("nonadmin", "pass")
# Attempt quarantine media APIs as non-admin
url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
channel = self.make_request( channel = self.make_request(
"POST", "POST",
url.encode("ascii"), url.encode("ascii"),
@ -235,22 +199,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
msg="Expected forbidden on quarantining media as a non-admin", msg="Expected forbidden on quarantining media as a non-admin",
) )
# And the roomID/userID endpoint def test_quarantine_media_by_id(self) -> None:
url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
channel = self.make_request(
"POST",
url.encode("ascii"),
access_token=non_admin_user_tok,
)
# Expect a forbidden error
self.assertEqual(
HTTPStatus.FORBIDDEN,
channel.code,
msg="Expected forbidden on quarantining media as a non-admin",
)
def test_quarantine_media_by_id(self):
self.register_user("id_admin", "pass", admin=True) self.register_user("id_admin", "pass", admin=True)
admin_user_tok = self.login("id_admin", "pass") admin_user_tok = self.login("id_admin", "pass")
@ -295,7 +244,15 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt to access the media # Attempt to access the media
self._ensure_quarantined(admin_user_tok, server_name_and_media_id) self._ensure_quarantined(admin_user_tok, server_name_and_media_id)
def test_quarantine_all_media_in_room(self, override_url_template=None): @parameterized.expand(
[
# regular API path
"/_synapse/admin/v1/room/%s/media/quarantine",
# deprecated API path
"/_synapse/admin/v1/quarantine_media/%s",
]
)
def test_quarantine_all_media_in_room(self, url: str) -> None:
self.register_user("room_admin", "pass", admin=True) self.register_user("room_admin", "pass", admin=True)
admin_user_tok = self.login("room_admin", "pass") admin_user_tok = self.login("room_admin", "pass")
@ -333,16 +290,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
tok=non_admin_user_tok, tok=non_admin_user_tok,
) )
# Quarantine all media in the room
if override_url_template:
url = override_url_template % urllib.parse.quote(room_id)
else:
url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
room_id
)
channel = self.make_request( channel = self.make_request(
"POST", "POST",
url, url % urllib.parse.quote(room_id),
access_token=admin_user_tok, access_token=admin_user_tok,
) )
self.pump(1.0) self.pump(1.0)
@ -359,11 +309,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1) self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
self._ensure_quarantined(admin_user_tok, server_and_media_id_2) self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
def test_quarantine_all_media_in_room_deprecated_api_path(self): def test_quarantine_all_media_by_user(self) -> None:
# Perform the above test with the deprecated API path
self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s")
def test_quarantine_all_media_by_user(self):
self.register_user("user_admin", "pass", admin=True) self.register_user("user_admin", "pass", admin=True)
admin_user_tok = self.login("user_admin", "pass") admin_user_tok = self.login("user_admin", "pass")
@ -401,7 +347,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1) self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
self._ensure_quarantined(admin_user_tok, server_and_media_id_2) self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
def test_cannot_quarantine_safe_media(self): def test_cannot_quarantine_safe_media(self) -> None:
self.register_user("user_admin", "pass", admin=True) self.register_user("user_admin", "pass", admin=True)
admin_user_tok = self.login("user_admin", "pass") admin_user_tok = self.login("user_admin", "pass")
@ -475,7 +421,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
@ -488,7 +434,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
self.url = f"/_synapse/admin/v1/purge_history/{self.room_id}" self.url = f"/_synapse/admin/v1/purge_history/{self.room_id}"
self.url_status = "/_synapse/admin/v1/purge_history_status/" self.url_status = "/_synapse/admin/v1/purge_history_status/"
def test_purge_history(self): def test_purge_history(self) -> None:
""" """
Simple test of purge history API. Simple test of purge history API.
Test only that is is possible to call, get status HTTPStatus.OK and purge_id. Test only that is is possible to call, get status HTTPStatus.OK and purge_id.

File diff suppressed because it is too large Load diff

View file

@ -14,9 +14,13 @@
from http import HTTPStatus from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.rest.client import login from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -28,11 +32,11 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
] ]
url = "/_synapse/admin/v1/username_available" url = "/_synapse/admin/v1/username_available"
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.register_user("admin", "pass", admin=True) self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
async def check_username(username): async def check_username(username: str) -> bool:
if username == "allowed": if username == "allowed":
return True return True
raise SynapseError( raise SynapseError(
@ -44,24 +48,24 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
handler = self.hs.get_registration_handler() handler = self.hs.get_registration_handler()
handler.check_username = check_username handler.check_username = check_username
def test_username_available(self): def test_username_available(self) -> None:
""" """
The endpoint should return a HTTPStatus.OK response if the username does not exist The endpoint should return a HTTPStatus.OK response if the username does not exist
""" """
url = "%s?username=%s" % (self.url, "allowed") url = "%s?username=%s" % (self.url, "allowed")
channel = self.make_request("GET", url, None, self.admin_user_tok) channel = self.make_request("GET", url, access_token=self.admin_user_tok)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["available"]) self.assertTrue(channel.json_body["available"])
def test_username_unavailable(self): def test_username_unavailable(self) -> None:
""" """
The endpoint should return a HTTPStatus.OK response if the username does not exist The endpoint should return a HTTPStatus.OK response if the username does not exist
""" """
url = "%s?username=%s" % (self.url, "disallowed") url = "%s?username=%s" % (self.url, "disallowed")
channel = self.make_request("GET", url, None, self.admin_user_tok) channel = self.make_request("GET", url, access_token=self.admin_user_tok)
self.assertEqual( self.assertEqual(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,