Clean up synapse.rest.admin (#11535)

This commit is contained in:
Dirk Klimpel 2021-12-08 17:59:40 +01:00 committed by GitHub
parent 83a74d9350
commit 7ecaa3b976
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 96 additions and 165 deletions

1
changelog.d/11535.misc Normal file
View file

@ -0,0 +1 @@
Clean up `synapse.rest.admin`.

View file

@ -108,7 +108,7 @@ class VersionServlet(RestServlet):
class PurgeHistoryRestServlet(RestServlet): class PurgeHistoryRestServlet(RestServlet):
PATTERNS = admin_patterns( PATTERNS = admin_patterns(
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?" "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]*))?$"
) )
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
@ -195,7 +195,7 @@ class PurgeHistoryRestServlet(RestServlet):
class PurgeHistoryStatusRestServlet(RestServlet): class PurgeHistoryStatusRestServlet(RestServlet):
PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)") PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]*)$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.pagination_handler = hs.get_pagination_handler() self.pagination_handler = hs.get_pagination_handler()

View file

@ -22,7 +22,7 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
) )
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict from synapse.types import JsonDict
if TYPE_CHECKING: if TYPE_CHECKING:
@ -41,8 +41,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
self._data_stores = hs.get_datastores() self._data_stores = hs.get_datastores()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request) await assert_requester_is_admin(self._auth, request)
await assert_user_is_admin(self._auth, requester.user)
# We need to check that all configured databases have updates enabled. # We need to check that all configured databases have updates enabled.
# (They *should* all be in sync.) # (They *should* all be in sync.)
@ -51,8 +50,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
return HTTPStatus.OK, {"enabled": enabled} return HTTPStatus.OK, {"enabled": enabled}
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request) await assert_requester_is_admin(self._auth, request)
await assert_user_is_admin(self._auth, requester.user)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -84,8 +82,7 @@ class BackgroundUpdateRestServlet(RestServlet):
self._data_stores = hs.get_datastores() self._data_stores = hs.get_datastores()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request) await assert_requester_is_admin(self._auth, request)
await assert_user_is_admin(self._auth, requester.user)
# We need to check that all configured databases have updates enabled. # We need to check that all configured databases have updates enabled.
# (They *should* all be in sync.) # (They *should* all be in sync.)
@ -111,15 +108,14 @@ class BackgroundUpdateRestServlet(RestServlet):
class BackgroundUpdateStartJobRestServlet(RestServlet): class BackgroundUpdateStartJobRestServlet(RestServlet):
"""Allows to start specific background updates""" """Allows to start specific background updates"""
PATTERNS = admin_patterns("/background_updates/start_job") PATTERNS = admin_patterns("/background_updates/start_job$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth() self._auth = hs.get_auth()
self._store = hs.get_datastore() self._store = hs.get_datastore()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request) await assert_requester_is_admin(self._auth, request)
await assert_user_is_admin(self._auth, requester.user)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["job_name"]) assert_params_in_dict(body, ["job_name"])

View file

@ -42,10 +42,10 @@ class DeviceRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.is_mine = hs.is_mine
async def on_GET( async def on_GET(
self, request: SynapseRequest, user_id: str, device_id: str self, request: SynapseRequest, user_id: str, device_id: str
@ -53,7 +53,7 @@ class DeviceRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user): if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string()) u = await self.store.get_user_by_id(target_user.to_string())
@ -71,7 +71,7 @@ class DeviceRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user): if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string()) u = await self.store.get_user_by_id(target_user.to_string())
@ -87,7 +87,7 @@ class DeviceRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user): if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string()) u = await self.store.get_user_by_id(target_user.to_string())
@ -109,14 +109,10 @@ class DevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
"""
Args:
hs: server
"""
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.is_mine = hs.is_mine
async def on_GET( async def on_GET(
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
@ -124,7 +120,7 @@ class DevicesRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user): if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string()) u = await self.store.get_user_by_id(target_user.to_string())
@ -144,10 +140,10 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.is_mine = hs.is_mine
async def on_POST( async def on_POST(
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
@ -155,7 +151,7 @@ class DeleteDevicesRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user): if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string()) u = await self.store.get_user_by_id(target_user.to_string())

View file

@ -52,7 +52,6 @@ class EventReportsRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports$") PATTERNS = admin_patterns("/event_reports$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -115,7 +114,6 @@ class EventReportDetailRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$") PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()

View file

@ -100,7 +100,7 @@ class DestinationsRestServlet(RestServlet):
200 OK with details of a destination if success otherwise an error. 200 OK with details of a destination if success otherwise an error.
""" """
PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]+)$") PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]*)$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth() self._auth = hs.get_auth()

View file

@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class DeleteGroupAdminRestServlet(RestServlet): class DeleteGroupAdminRestServlet(RestServlet):
"""Allows deleting of local groups""" """Allows deleting of local groups"""
PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)") PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.group_server = hs.get_groups_server_handler() self.group_server = hs.get_groups_server_handler()

View file

@ -17,7 +17,7 @@ import logging
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@ -41,9 +41,9 @@ class QuarantineMediaInRoom(RestServlet):
""" """
PATTERNS = [ PATTERNS = [
*admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine$"), *admin_patterns("/room/(?P<room_id>[^/]*)/media/quarantine$"),
# This path kept around for legacy reasons # This path kept around for legacy reasons
*admin_patterns("/quarantine_media/(?P<room_id>[^/]+)"), *admin_patterns("/quarantine_media/(?P<room_id>[^/]*)$"),
] ]
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
@ -71,7 +71,7 @@ class QuarantineMediaByUser(RestServlet):
this server. this server.
""" """
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine$") PATTERNS = admin_patterns("/user/(?P<user_id>[^/]*)/media/quarantine$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -99,7 +99,7 @@ class QuarantineMediaByID(RestServlet):
""" """
PATTERNS = admin_patterns( PATTERNS = admin_patterns(
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)" "/media/quarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
) )
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
@ -128,7 +128,7 @@ class UnquarantineMediaByID(RestServlet):
""" """
PATTERNS = admin_patterns( PATTERNS = admin_patterns(
"/media/unquarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)" "/media/unquarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
) )
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
@ -138,8 +138,7 @@ class UnquarantineMediaByID(RestServlet):
async def on_POST( async def on_POST(
self, request: SynapseRequest, server_name: str, media_id: str self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) await assert_requester_is_admin(self.auth, request)
await assert_user_is_admin(self.auth, requester.user)
logging.info( logging.info(
"Remove from quarantine local media by ID: %s/%s", server_name, media_id "Remove from quarantine local media by ID: %s/%s", server_name, media_id
@ -154,7 +153,7 @@ class UnquarantineMediaByID(RestServlet):
class ProtectMediaByID(RestServlet): class ProtectMediaByID(RestServlet):
"""Protect local media from being quarantined.""" """Protect local media from being quarantined."""
PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)") PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]*)$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -163,8 +162,7 @@ class ProtectMediaByID(RestServlet):
async def on_POST( async def on_POST(
self, request: SynapseRequest, media_id: str self, request: SynapseRequest, media_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) await assert_requester_is_admin(self.auth, request)
await assert_user_is_admin(self.auth, requester.user)
logging.info("Protecting local media by ID: %s", media_id) logging.info("Protecting local media by ID: %s", media_id)
@ -177,7 +175,7 @@ class ProtectMediaByID(RestServlet):
class UnprotectMediaByID(RestServlet): class UnprotectMediaByID(RestServlet):
"""Unprotect local media from being quarantined.""" """Unprotect local media from being quarantined."""
PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]+)") PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]*)$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -186,8 +184,7 @@ class UnprotectMediaByID(RestServlet):
async def on_POST( async def on_POST(
self, request: SynapseRequest, media_id: str self, request: SynapseRequest, media_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) await assert_requester_is_admin(self.auth, request)
await assert_user_is_admin(self.auth, requester.user)
logging.info("Unprotecting local media by ID: %s", media_id) logging.info("Unprotecting local media by ID: %s", media_id)
@ -200,7 +197,7 @@ class UnprotectMediaByID(RestServlet):
class ListMediaInRoom(RestServlet): class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.""" """Lists all of the media in a given room."""
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media$") PATTERNS = admin_patterns("/room/(?P<room_id>[^/]*)/media$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -209,10 +206,7 @@ class ListMediaInRoom(RestServlet):
async def on_GET( async def on_GET(
self, request: SynapseRequest, room_id: str self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) await assert_requester_is_admin(self.auth, request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id) local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
@ -254,7 +248,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
class DeleteMediaByID(RestServlet): class DeleteMediaByID(RestServlet):
"""Delete local media by a given ID. Removes it from this server.""" """Delete local media by a given ID. Removes it from this server."""
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)") PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -286,7 +280,7 @@ class DeleteMediaByDateSize(RestServlet):
timestamp and size. timestamp and size.
""" """
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete$") PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -353,7 +347,7 @@ class UserMediaRestServlet(RestServlet):
media that exist given for this user media that exist given for this user
""" """
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/media$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
@ -403,16 +397,7 @@ class UserMediaRestServlet(RestServlet):
request, request,
"order_by", "order_by",
default=MediaSortOrder.CREATED_TS.value, default=MediaSortOrder.CREATED_TS.value,
allowed_values=( allowed_values=[sort_order.value for sort_order in MediaSortOrder],
MediaSortOrder.MEDIA_ID.value,
MediaSortOrder.UPLOAD_NAME.value,
MediaSortOrder.CREATED_TS.value,
MediaSortOrder.LAST_ACCESS_TS.value,
MediaSortOrder.MEDIA_LENGTH.value,
MediaSortOrder.MEDIA_TYPE.value,
MediaSortOrder.QUARANTINED_BY.value,
MediaSortOrder.SAFE_FROM_QUARANTINE.value,
),
) )
direction = parse_string( direction = parse_string(
request, "dir", default="f", allowed_values=("f", "b") request, "dir", default="f", allowed_values=("f", "b")
@ -470,16 +455,7 @@ class UserMediaRestServlet(RestServlet):
request, request,
"order_by", "order_by",
default=MediaSortOrder.CREATED_TS.value, default=MediaSortOrder.CREATED_TS.value,
allowed_values=( allowed_values=[sort_order.value for sort_order in MediaSortOrder],
MediaSortOrder.MEDIA_ID.value,
MediaSortOrder.UPLOAD_NAME.value,
MediaSortOrder.CREATED_TS.value,
MediaSortOrder.LAST_ACCESS_TS.value,
MediaSortOrder.MEDIA_LENGTH.value,
MediaSortOrder.MEDIA_TYPE.value,
MediaSortOrder.QUARANTINED_BY.value,
MediaSortOrder.SAFE_FROM_QUARANTINE.value,
),
) )
direction = parse_string( direction = parse_string(
request, "dir", default="f", allowed_values=("f", "b") request, "dir", default="f", allowed_values=("f", "b")

View file

@ -70,7 +70,6 @@ class ListRegistrationTokensRestServlet(RestServlet):
PATTERNS = admin_patterns("/registration_tokens$") PATTERNS = admin_patterns("/registration_tokens$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -109,7 +108,6 @@ class NewRegistrationTokenRestServlet(RestServlet):
PATTERNS = admin_patterns("/registration_tokens/new$") PATTERNS = admin_patterns("/registration_tokens/new$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -260,7 +258,6 @@ class RegistrationTokenRestServlet(RestServlet):
PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$") PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()

View file

@ -61,7 +61,7 @@ class RoomRestV2Servlet(RestServlet):
If 'purge' is true, it will remove all traces of a room from the database. If 'purge' is true, it will remove all traces of a room from the database.
""" """
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$", "v2") PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$", "v2")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth() self._auth = hs.get_auth()
@ -123,7 +123,7 @@ class RoomRestV2Servlet(RestServlet):
class DeleteRoomStatusByRoomIdRestServlet(RestServlet): class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
"""Get the status of the delete room background task.""" """Get the status of the delete room background task."""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete_status$", "v2") PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/delete_status$", "v2")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth() self._auth = hs.get_auth()
@ -160,7 +160,7 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
"""Get the status of the delete room background task.""" """Get the status of the delete room background task."""
PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]+)$", "v2") PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]*)$", "v2")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth() self._auth = hs.get_auth()
@ -193,34 +193,16 @@ class ListRoomRestServlet(RestServlet):
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) await assert_requester_is_admin(self.auth, request)
await assert_user_is_admin(self.auth, requester.user)
# Extract query parameters # Extract query parameters
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100) limit = parse_integer(request, "limit", default=100)
order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value) order_by = parse_string(
if order_by not in ( request,
RoomSortOrder.ALPHABETICAL.value, "order_by",
RoomSortOrder.SIZE.value, default=RoomSortOrder.NAME.value,
RoomSortOrder.NAME.value, allowed_values=[sort_order.value for sort_order in RoomSortOrder],
RoomSortOrder.CANONICAL_ALIAS.value,
RoomSortOrder.JOINED_MEMBERS.value,
RoomSortOrder.JOINED_LOCAL_MEMBERS.value,
RoomSortOrder.VERSION.value,
RoomSortOrder.CREATOR.value,
RoomSortOrder.ENCRYPTION.value,
RoomSortOrder.FEDERATABLE.value,
RoomSortOrder.PUBLIC.value,
RoomSortOrder.JOIN_RULES.value,
RoomSortOrder.GUEST_ACCESS.value,
RoomSortOrder.HISTORY_VISIBILITY.value,
RoomSortOrder.STATE_EVENTS.value,
):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Unknown value for order_by: %s" % (order_by,),
errcode=Codes.INVALID_PARAM,
) )
search_term = parse_string(request, "search_term", encoding="utf-8") search_term = parse_string(request, "search_term", encoding="utf-8")
@ -292,10 +274,9 @@ class RoomRestServlet(RestServlet):
TODO: Add on_POST to allow room creation without joining the room TODO: Add on_POST to allow room creation without joining the room
""" """
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$") PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.room_shutdown_handler = hs.get_room_shutdown_handler() self.room_shutdown_handler = hs.get_room_shutdown_handler()
@ -397,10 +378,9 @@ class RoomMembersRestServlet(RestServlet):
Get members list of a room. Get members list of a room.
""" """
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members") PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/members$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -424,10 +404,9 @@ class RoomStateRestServlet(RestServlet):
Get full state within a room. Get full state within a room.
""" """
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state") PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/state$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -436,8 +415,7 @@ class RoomStateRestServlet(RestServlet):
async def on_GET( async def on_GET(
self, request: SynapseRequest, room_id: str self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) await assert_requester_is_admin(self.auth, request)
await assert_user_is_admin(self.auth, requester.user)
ret = await self.store.get_room(room_id) ret = await self.store.get_room(room_id)
if not ret: if not ret:
@ -454,14 +432,14 @@ class RoomStateRestServlet(RestServlet):
class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)") PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.is_mine = hs.is_mine
async def on_POST( async def on_POST(
self, request: SynapseRequest, room_identifier: str self, request: SynapseRequest, room_identifier: str
@ -477,7 +455,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
assert_params_in_dict(content, ["user_id"]) assert_params_in_dict(content, ["user_id"])
target_user = UserID.from_string(content["user_id"]) target_user = UserID.from_string(content["user_id"])
if not self.hs.is_mine(target_user): if not self.is_mine(target_user):
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
"This endpoint can only be used with local users", "This endpoint can only be used with local users",
@ -542,11 +520,10 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
} }
""" """
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin") PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
@ -688,19 +665,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
""" """
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities") PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_DELETE( async def on_DELETE(
self, request: SynapseRequest, room_identifier: str self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) await assert_requester_is_admin(self.auth, request)
await assert_user_is_admin(self.auth, requester.user)
room_id, _ = await self.resolve_room_id(room_identifier) room_id, _ = await self.resolve_room_id(room_identifier)
@ -710,8 +685,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
async def on_GET( async def on_GET(
self, request: SynapseRequest, room_identifier: str self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) await assert_requester_is_admin(self.auth, request)
await assert_user_is_admin(self.auth, requester.user)
room_id, _ = await self.resolve_room_id(room_identifier) room_id, _ = await self.resolve_room_id(room_identifier)
@ -793,7 +767,7 @@ class BlockRoomRestServlet(RestServlet):
On GET: Get blocking status of room and user who has blocked this room. On GET: Get blocking status of room and user who has blocked this room.
""" """
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/block$") PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/block$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth() self._auth = hs.get_auth()

View file

@ -52,11 +52,11 @@ class SendServerNoticeServlet(RestServlet):
""" """
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.server_notices_manager = hs.get_server_notices_manager() self.server_notices_manager = hs.get_server_notices_manager()
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs) self.txns = HttpTransactionCache(hs)
self.is_mine = hs.is_mine
def register(self, json_resource: HttpServer) -> None: def register(self, json_resource: HttpServer) -> None:
PATTERN = "/send_server_notice" PATTERN = "/send_server_notice"
@ -88,7 +88,7 @@ class SendServerNoticeServlet(RestServlet):
) )
target_user = UserID.from_string(body["user_id"]) target_user = UserID.from_string(body["user_id"])
if not self.hs.is_mine(target_user): if not self.is_mine(target_user):
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users" HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users"
) )

View file

@ -37,7 +37,6 @@ class UserMediaStatisticsRestServlet(RestServlet):
PATTERNS = admin_patterns("/statistics/users/media$") PATTERNS = admin_patterns("/statistics/users/media$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -45,18 +44,15 @@ class UserMediaStatisticsRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
order_by = parse_string( order_by = parse_string(
request, "order_by", default=UserSortOrder.USER_ID.value request,
) "order_by",
if order_by not in ( default=UserSortOrder.USER_ID.value,
allowed_values=(
UserSortOrder.MEDIA_LENGTH.value, UserSortOrder.MEDIA_LENGTH.value,
UserSortOrder.MEDIA_COUNT.value, UserSortOrder.MEDIA_COUNT.value,
UserSortOrder.USER_ID.value, UserSortOrder.USER_ID.value,
UserSortOrder.DISPLAYNAME.value, UserSortOrder.DISPLAYNAME.value,
): ),
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Unknown value for order_by: %s" % (order_by,),
errcode=Codes.INVALID_PARAM,
) )
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0)

View file

@ -37,7 +37,7 @@ class UsernameAvailableRestServlet(RestServlet):
} }
""" """
PATTERNS = admin_patterns("/username_available") PATTERNS = admin_patterns("/username_available$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth() self.auth = hs.get_auth()

View file

@ -66,7 +66,6 @@ class UsersRestServletV2(RestServlet):
""" """
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
@ -126,7 +125,7 @@ class UsersRestServletV2(RestServlet):
class UserRestServletV2(RestServlet): class UserRestServletV2(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$", "v2")
"""Get request to list user details. """Get request to list user details.
This needs user to have administrator access in Synapse. This needs user to have administrator access in Synapse.
@ -414,7 +413,7 @@ class UserRegisterServlet(RestServlet):
nonce to the time it was generated, in int seconds. nonce to the time it was generated, in int seconds.
""" """
PATTERNS = admin_patterns("/register") PATTERNS = admin_patterns("/register$")
NONCE_TIMEOUT = 60 NONCE_TIMEOUT = 60
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
@ -561,9 +560,9 @@ class WhoisRestServlet(RestServlet):
] ]
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
self.is_mine = hs.is_mine
async def on_GET( async def on_GET(
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
@ -575,7 +574,7 @@ class WhoisRestServlet(RestServlet):
if target_user != auth_user: if target_user != auth_user:
await assert_user_is_admin(self.auth, auth_user) await assert_user_is_admin(self.auth, auth_user)
if not self.hs.is_mine(target_user): if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user") raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
ret = await self.admin_handler.get_whois(target_user) ret = await self.admin_handler.get_whois(target_user)
@ -584,7 +583,7 @@ class WhoisRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet):
PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)") PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._deactivate_account_handler = hs.get_deactivate_account_handler() self._deactivate_account_handler = hs.get_deactivate_account_handler()
@ -630,7 +629,6 @@ class AccountValidityRenewServlet(RestServlet):
PATTERNS = admin_patterns("/account_validity/validity$") PATTERNS = admin_patterns("/account_validity/validity$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler() self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -674,11 +672,10 @@ class ResetPasswordRestServlet(RestServlet):
200 OK with empty object if success otherwise an error. 200 OK with empty object if success otherwise an error.
""" """
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)") PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self._set_password_handler = hs.get_set_password_handler() self._set_password_handler = hs.get_set_password_handler()
@ -718,12 +715,12 @@ class SearchUsersRestServlet(RestServlet):
200 OK with json object {list[dict[str, Any]], count} or empty object. 200 OK with json object {list[dict[str, Any]], count} or empty object.
""" """
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)") PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.is_mine = hs.is_mine
async def on_GET( async def on_GET(
self, request: SynapseRequest, target_user_id: str self, request: SynapseRequest, target_user_id: str
@ -740,7 +737,7 @@ class SearchUsersRestServlet(RestServlet):
# if not is_admin and target_user != auth_user: # if not is_admin and target_user != auth_user:
# raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") # raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
if not self.hs.is_mine(target_user): if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user") raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user")
term = parse_string(request, "term", required=True) term = parse_string(request, "term", required=True)
@ -779,9 +776,9 @@ class UserAdminServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.is_mine = hs.is_mine
async def on_GET( async def on_GET(
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
@ -790,7 +787,7 @@ class UserAdminServlet(RestServlet):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user): if not self.is_mine(target_user):
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
"Only local users can be admins of this homeserver", "Only local users can be admins of this homeserver",
@ -813,7 +810,7 @@ class UserAdminServlet(RestServlet):
assert_params_in_dict(body, ["admin"]) assert_params_in_dict(body, ["admin"])
if not self.hs.is_mine(target_user): if not self.is_mine(target_user):
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
"Only local users can be admins of this homeserver", "Only local users can be admins of this homeserver",
@ -834,7 +831,7 @@ class UserMembershipRestServlet(RestServlet):
Get room list of an user. Get room list of an user.
""" """
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/joined_rooms$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
@ -909,10 +906,10 @@ class UserTokenRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.is_mine_id = hs.is_mine_id
async def on_POST( async def on_POST(
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
@ -921,7 +918,7 @@ class UserTokenRestServlet(RestServlet):
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user auth_user = requester.user
if not self.hs.is_mine_id(user_id): if not self.is_mine_id(user_id):
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be logged in as" HTTPStatus.BAD_REQUEST, "Only local users can be logged in as"
) )
@ -975,19 +972,19 @@ class ShadowBanRestServlet(RestServlet):
{} {}
""" """
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.is_mine_id = hs.is_mine_id
async def on_POST( async def on_POST(
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id): if not self.is_mine_id(user_id):
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
) )
@ -1001,7 +998,7 @@ class ShadowBanRestServlet(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id): if not self.is_mine_id(user_id):
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
) )
@ -1027,19 +1024,19 @@ class RateLimitRestServlet(RestServlet):
} }
""" """
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit$")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.is_mine_id = hs.is_mine_id
async def on_GET( async def on_GET(
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id): if not self.is_mine_id(user_id):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
if not await self.store.get_user_by_id(user_id): if not await self.store.get_user_by_id(user_id):
@ -1068,7 +1065,7 @@ class RateLimitRestServlet(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id): if not self.is_mine_id(user_id):
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
) )
@ -1113,7 +1110,7 @@ class RateLimitRestServlet(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id): if not self.is_mine_id(user_id):
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
) )

View file

@ -92,7 +92,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
channel.code, channel.code,
msg=channel.json_body, msg=channel.json_body,
) )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# negative from # negative from
channel = self.make_request( channel = self.make_request(