Add missing type hints to the admin API servlets (#10105)

This commit is contained in:
Dirk Klimpel 2021-06-07 16:12:34 +02:00 committed by GitHub
parent fa1db8f156
commit d558292548
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 48 additions and 40 deletions

1
changelog.d/10105.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to the admin API servlets.

View file

@ -17,11 +17,13 @@
import logging import logging
import platform import platform
from typing import TYPE_CHECKING, Optional, Tuple
import synapse import synapse
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import JsonResource from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.rest.admin.devices import ( from synapse.rest.admin.devices import (
DeleteDevicesRestServlet, DeleteDevicesRestServlet,
@ -66,22 +68,25 @@ from synapse.rest.admin.users import (
UserTokenRestServlet, UserTokenRestServlet,
WhoisRestServlet, WhoisRestServlet,
) )
from synapse.types import RoomStreamToken from synapse.types import JsonDict, RoomStreamToken
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class VersionServlet(RestServlet): class VersionServlet(RestServlet):
PATTERNS = admin_patterns("/server_version$") PATTERNS = admin_patterns("/server_version$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.res = { self.res = {
"server_version": get_version_string(synapse), "server_version": get_version_string(synapse),
"python_version": platform.python_version(), "python_version": platform.python_version(),
} }
def on_GET(self, request): def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
return 200, self.res return 200, self.res
@ -90,17 +95,14 @@ class PurgeHistoryRestServlet(RestServlet):
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?" "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer)
"""
self.pagination_handler = hs.get_pagination_handler() self.pagination_handler = hs.get_pagination_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_POST(self, request, room_id, event_id): async def on_POST(
self, request: SynapseRequest, room_id: str, event_id: Optional[str]
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request, allow_empty_body=True) body = parse_json_object_from_request(request, allow_empty_body=True)
@ -119,6 +121,8 @@ class PurgeHistoryRestServlet(RestServlet):
if event.room_id != room_id: if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.") raise SynapseError(400, "Event is for wrong room.")
# RoomStreamToken expects [int] not Optional[int]
assert event.internal_metadata.stream_ordering is not None
room_token = RoomStreamToken( room_token = RoomStreamToken(
event.depth, event.internal_metadata.stream_ordering event.depth, event.internal_metadata.stream_ordering
) )
@ -173,16 +177,13 @@ class PurgeHistoryRestServlet(RestServlet):
class PurgeHistoryStatusRestServlet(RestServlet): class PurgeHistoryStatusRestServlet(RestServlet):
PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)") PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer)
"""
self.pagination_handler = hs.get_pagination_handler() self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request, purge_id): async def on_GET(
self, request: SynapseRequest, purge_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
purge_status = self.pagination_handler.get_purge_status(purge_id) purge_status = self.pagination_handler.get_purge_status(purge_id)
@ -203,12 +204,12 @@ class PurgeHistoryStatusRestServlet(RestServlet):
class AdminRestResource(JsonResource): class AdminRestResource(JsonResource):
"""The REST resource which gets mounted at /_synapse/admin""" """The REST resource which gets mounted at /_synapse/admin"""
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
JsonResource.__init__(self, hs, canonical_json=False) JsonResource.__init__(self, hs, canonical_json=False)
register_servlets(hs, self) register_servlets(hs, self)
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
""" """
Register all the admin servlets. Register all the admin servlets.
""" """
@ -242,7 +243,9 @@ def register_servlets(hs, http_server):
RateLimitRestServlet(hs).register(http_server) RateLimitRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server): def register_servlets_for_client_rest_resource(
hs: "HomeServer", http_server: HttpServer
) -> None:
"""Register only the servlets which need to be exposed on /_matrix/client/xxx""" """Register only the servlets which need to be exposed on /_matrix/client/xxx"""
WhoisRestServlet(hs).register(http_server) WhoisRestServlet(hs).register(http_server)
PurgeHistoryStatusRestServlet(hs).register(http_server) PurgeHistoryStatusRestServlet(hs).register(http_server)

View file

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import re import re
from typing import Iterable, Pattern
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
@ -20,7 +21,7 @@ from synapse.http.site import SynapseRequest
from synapse.types import UserID from synapse.types import UserID
def admin_patterns(path_regex: str, version: str = "v1"): def admin_patterns(path_regex: str, version: str = "v1") -> Iterable[Pattern]:
"""Returns the list of patterns for an admin endpoint """Returns the list of patterns for an admin endpoint
Args: Args:

View file

@ -12,10 +12,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -25,12 +31,14 @@ class DeleteGroupAdminRestServlet(RestServlet):
PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)") PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.group_server = hs.get_groups_server_handler() self.group_server = hs.get_groups_server_handler()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_POST(self, request, group_id): async def on_POST(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)

View file

@ -17,6 +17,7 @@ import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import ( from synapse.rest.admin._base import (
@ -37,12 +38,11 @@ class QuarantineMediaInRoom(RestServlet):
this server. this server.
""" """
PATTERNS = ( PATTERNS = [
admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine") *admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine"),
+
# This path kept around for legacy reasons # This path kept around for legacy reasons
admin_patterns("/quarantine_media/(?P<room_id>[^/]+)") *admin_patterns("/quarantine_media/(?P<room_id>[^/]+)"),
) ]
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -312,7 +312,7 @@ class DeleteMediaByDateSize(RestServlet):
return 200, {"deleted_media": deleted_media, "total": total} return 200, {"deleted_media": deleted_media, "total": total}
def register_servlets_for_media_repo(hs: "HomeServer", http_server): def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) -> None:
""" """
Media repo specific APIs. Media repo specific APIs.
""" """

View file

@ -478,13 +478,12 @@ class UserRegisterServlet(RestServlet):
class WhoisRestServlet(RestServlet): class WhoisRestServlet(RestServlet):
path_regex = "/whois/(?P<user_id>[^/]*)$" path_regex = "/whois/(?P<user_id>[^/]*)$"
PATTERNS = ( PATTERNS = [
admin_patterns(path_regex) *admin_patterns(path_regex),
+
# URL for spec reason # URL for spec reason
# https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid # https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid
client_patterns("/admin" + path_regex, v1=True) *client_patterns("/admin" + path_regex, v1=True),
) ]
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
@ -553,11 +552,7 @@ class DeactivateAccountRestServlet(RestServlet):
class AccountValidityRenewServlet(RestServlet): class AccountValidityRenewServlet(RestServlet):
PATTERNS = admin_patterns("/account_validity/validity$") PATTERNS = admin_patterns("/account_validity/validity$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
self.hs = hs self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler() self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()