0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 10:53:54 +01:00

Addtional type hints for the REST servlets. (#10665)

This commit is contained in:
Patrick Cloke 2021-08-23 08:14:17 -04:00 committed by GitHub
parent 31dac7ffee
commit 2af6d31b78
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 204 additions and 107 deletions

1
changelog.d/10665.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to REST servlets.

View file

@ -13,24 +13,27 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import SynapseError from twisted.web.server import Request
from synapse.http.server import respond_with_html
from synapse.http.servlet import RestServlet from synapse.http.server import HttpServer, respond_with_html
from synapse.http.servlet import RestServlet, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AccountValidityRenewServlet(RestServlet): class AccountValidityRenewServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/renew$") PATTERNS = client_patterns("/account_validity/renew$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
@ -46,18 +49,14 @@ class AccountValidityRenewServlet(RestServlet):
hs.config.account_validity.account_validity_invalid_token_template hs.config.account_validity.account_validity_invalid_token_template
) )
async def on_GET(self, request): async def on_GET(self, request: Request) -> None:
if b"token" not in request.args: renewal_token = parse_string(request, "token", required=True)
raise SynapseError(400, "Missing renewal token")
renewal_token = request.args[b"token"][0]
( (
token_valid, token_valid,
token_stale, token_stale,
expiration_ts, expiration_ts,
) = await self.account_activity_handler.renew_account( ) = await self.account_activity_handler.renew_account(renewal_token)
renewal_token.decode("utf8")
)
if token_valid: if token_valid:
status_code = 200 status_code = 200
@ -77,11 +76,7 @@ class AccountValidityRenewServlet(RestServlet):
class AccountValiditySendMailServlet(RestServlet): class AccountValiditySendMailServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/send_mail$") PATTERNS = client_patterns("/account_validity/send_mail$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
@ -91,7 +86,7 @@ class AccountValiditySendMailServlet(RestServlet):
hs.config.account_validity.account_validity_renew_by_email_enabled hs.config.account_validity.account_validity_renew_by_email_enabled
) )
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True) requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
await self.account_activity_handler.send_renewal_email_to_user(user_id) await self.account_activity_handler.send_renewal_email_to_user(user_id)
@ -99,6 +94,6 @@ class AccountValiditySendMailServlet(RestServlet):
return 200, {} return 200, {}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AccountValidityRenewServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server)
AccountValiditySendMailServlet(hs).register(http_server) AccountValiditySendMailServlet(hs).register(http_server)

View file

@ -15,6 +15,7 @@ import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.types import JsonDict from synapse.types import JsonDict
@ -75,5 +76,5 @@ class CapabilitiesRestServlet(RestServlet):
return 200, response return 200, response
def register_servlets(hs: "HomeServer", http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
CapabilitiesRestServlet(hs).register(http_server) CapabilitiesRestServlet(hs).register(http_server)

View file

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
@ -22,14 +24,19 @@ from synapse.api.errors import (
NotFoundError, NotFoundError,
SynapseError, SynapseError,
) )
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns from synapse.rest.client._base import client_patterns
from synapse.types import RoomAlias from synapse.types import JsonDict, RoomAlias
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ClientDirectoryServer(hs).register(http_server) ClientDirectoryServer(hs).register(http_server)
ClientDirectoryListServer(hs).register(http_server) ClientDirectoryListServer(hs).register(http_server)
ClientAppserviceDirectoryListServer(hs).register(http_server) ClientAppserviceDirectoryListServer(hs).register(http_server)
@ -38,21 +45,23 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(RestServlet): class ClientDirectoryServer(RestServlet):
PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True) PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler() self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request, room_alias): async def on_GET(self, request: Request, room_alias: str) -> Tuple[int, JsonDict]:
room_alias = RoomAlias.from_string(room_alias) room_alias_obj = RoomAlias.from_string(room_alias)
res = await self.directory_handler.get_association(room_alias) res = await self.directory_handler.get_association(room_alias_obj)
return 200, res return 200, res
async def on_PUT(self, request, room_alias): async def on_PUT(
room_alias = RoomAlias.from_string(room_alias) self, request: SynapseRequest, room_alias: str
) -> Tuple[int, JsonDict]:
room_alias_obj = RoomAlias.from_string(room_alias)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
if "room_id" not in content: if "room_id" not in content:
@ -61,7 +70,7 @@ class ClientDirectoryServer(RestServlet):
) )
logger.debug("Got content: %s", content) logger.debug("Got content: %s", content)
logger.debug("Got room name: %s", room_alias.to_string()) logger.debug("Got room name: %s", room_alias_obj.to_string())
room_id = content["room_id"] room_id = content["room_id"]
servers = content["servers"] if "servers" in content else None servers = content["servers"] if "servers" in content else None
@ -78,22 +87,25 @@ class ClientDirectoryServer(RestServlet):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await self.directory_handler.create_association( await self.directory_handler.create_association(
requester, room_alias, room_id, servers requester, room_alias_obj, room_id, servers
) )
return 200, {} return 200, {}
async def on_DELETE(self, request, room_alias): async def on_DELETE(
self, request: SynapseRequest, room_alias: str
) -> Tuple[int, JsonDict]:
room_alias_obj = RoomAlias.from_string(room_alias)
try: try:
service = self.auth.get_appservice_by_req(request) service = self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
await self.directory_handler.delete_appservice_association( await self.directory_handler.delete_appservice_association(
service, room_alias service, room_alias_obj
) )
logger.info( logger.info(
"Application service at %s deleted alias %s", "Application service at %s deleted alias %s",
service.url, service.url,
room_alias.to_string(), room_alias_obj.to_string(),
) )
return 200, {} return 200, {}
except InvalidClientCredentialsError: except InvalidClientCredentialsError:
@ -103,12 +115,10 @@ class ClientDirectoryServer(RestServlet):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user = requester.user user = requester.user
room_alias = RoomAlias.from_string(room_alias) await self.directory_handler.delete_association(requester, room_alias_obj)
await self.directory_handler.delete_association(requester, room_alias)
logger.info( logger.info(
"User %s deleted alias %s", user.to_string(), room_alias.to_string() "User %s deleted alias %s", user.to_string(), room_alias_obj.to_string()
) )
return 200, {} return 200, {}
@ -117,20 +127,22 @@ class ClientDirectoryServer(RestServlet):
class ClientDirectoryListServer(RestServlet): class ClientDirectoryListServer(RestServlet):
PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True) PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler() self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request, room_id): async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
room = await self.store.get_room(room_id) room = await self.store.get_room(room_id)
if room is None: if room is None:
raise NotFoundError("Unknown room") raise NotFoundError("Unknown room")
return 200, {"visibility": "public" if room["is_public"] else "private"} return 200, {"visibility": "public" if room["is_public"] else "private"}
async def on_PUT(self, request, room_id): async def on_PUT(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -142,7 +154,9 @@ class ClientDirectoryListServer(RestServlet):
return 200, {} return 200, {}
async def on_DELETE(self, request, room_id): async def on_DELETE(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await self.directory_handler.edit_published_room_list( await self.directory_handler.edit_published_room_list(
@ -157,21 +171,27 @@ class ClientAppserviceDirectoryListServer(RestServlet):
"/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$", v1=True "/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$", v1=True
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler() self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
def on_PUT(self, request, network_id, room_id): async def on_PUT(
self, request: SynapseRequest, network_id: str, room_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public") visibility = content.get("visibility", "public")
return self._edit(request, network_id, room_id, visibility) return await self._edit(request, network_id, room_id, visibility)
def on_DELETE(self, request, network_id, room_id): async def on_DELETE(
return self._edit(request, network_id, room_id, "private") self, request: SynapseRequest, network_id: str, room_id: str
) -> Tuple[int, JsonDict]:
return await self._edit(request, network_id, room_id, "private")
async def _edit(self, request, network_id, room_id, visibility): async def _edit(
self, request: SynapseRequest, network_id: str, room_id: str, visibility: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if not requester.app_service: if not requester.app_service:
raise AuthError( raise AuthError(

View file

@ -13,17 +13,23 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.http.server import respond_with_html_bytes from synapse.http.server import HttpServer, respond_with_html_bytes
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, assert_params_in_dict,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.http.site import SynapseRequest
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.rest.client._base import client_patterns from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -31,12 +37,12 @@ logger = logging.getLogger(__name__)
class PushersRestServlet(RestServlet): class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True) PATTERNS = client_patterns("/pushers$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user = requester.user user = requester.user
@ -50,14 +56,14 @@ class PushersRestServlet(RestServlet):
class PushersSetRestServlet(RestServlet): class PushersSetRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/set$", v1=True) PATTERNS = client_patterns("/pushers/set$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool() self.pusher_pool = self.hs.get_pusherpool()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user = requester.user user = requester.user
@ -132,14 +138,14 @@ class PushersRemoveRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/remove$", v1=True) PATTERNS = client_patterns("/pushers/remove$", v1=True)
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>" SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.pusher_pool = self.hs.get_pusherpool() self.pusher_pool = self.hs.get_pusherpool()
async def on_GET(self, request): async def on_GET(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request, rights="delete_pusher") requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
user = requester.user user = requester.user
@ -165,7 +171,7 @@ class PushersRemoveRestServlet(RestServlet):
return None return None
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PushersRestServlet(hs).register(http_server) PushersRestServlet(hs).register(http_server)
PushersSetRestServlet(hs).register(http_server) PushersSetRestServlet(hs).register(http_server)
PushersRemoveRestServlet(hs).register(http_server) PushersRemoveRestServlet(hs).register(http_server)

View file

@ -13,27 +13,36 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReadReceiptEventFields from synapse.api.constants import ReadReceiptEventFields
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReadMarkerRestServlet(RestServlet): class ReadMarkerRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler() self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler() self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
async def on_POST(self, request, room_id): async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await self.presence_handler.bump_presence_active_time(requester.user) await self.presence_handler.bump_presence_active_time(requester.user)
@ -70,5 +79,5 @@ class ReadMarkerRestServlet(RestServlet):
return 200, {} return 200, {}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReadMarkerRestServlet(hs).register(http_server) ReadMarkerRestServlet(hs).register(http_server)

View file

@ -13,18 +13,25 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, ShadowBanError, SynapseError from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.server import HttpServer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, assert_params_in_dict,
parse_json_object_from_request, parse_json_object_from_request,
) )
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from synapse.util import stringutils from synapse.util import stringutils
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,9 +48,6 @@ class RoomUpgradeRestServlet(RestServlet):
} }
Creates a new room and shuts down the old one. Returns the ID of the new room. Creates a new room and shuts down the old one. Returns the ID of the new room.
Args:
hs (synapse.server.HomeServer):
""" """
PATTERNS = client_patterns( PATTERNS = client_patterns(
@ -51,13 +55,15 @@ class RoomUpgradeRestServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/upgrade$" "/rooms/(?P<room_id>[^/]*)/upgrade$"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self._hs = hs self._hs = hs
self._room_creation_handler = hs.get_room_creation_handler() self._room_creation_handler = hs.get_room_creation_handler()
self._auth = hs.get_auth() self._auth = hs.get_auth()
async def on_POST(self, request, room_id): async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request) requester = await self._auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -84,5 +90,5 @@ class RoomUpgradeRestServlet(RestServlet):
return 200, ret return 200, ret
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomUpgradeRestServlet(hs).register(http_server) RoomUpgradeRestServlet(hs).register(http_server)

View file

@ -12,13 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.types import UserID from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -32,13 +38,15 @@ class UserSharedRoomsServlet(RestServlet):
releases=(), # This is an unstable feature releases=(), # This is an unstable feature
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.user_directory_active = hs.config.update_user_directory self.user_directory_active = hs.config.update_user_directory
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
if not self.user_directory_active: if not self.user_directory_active:
raise SynapseError( raise SynapseError(
@ -63,5 +71,5 @@ class UserSharedRoomsServlet(RestServlet):
return 200, {"joined": list(rooms)} return 200, {"joined": list(rooms)}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserSharedRoomsServlet(hs).register(http_server) UserSharedRoomsServlet(hs).register(http_server)

View file

@ -13,12 +13,19 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,12 +36,14 @@ class TagListServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags") PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request, user_id, room_id): async def on_GET(
self, request: SynapseRequest, user_id: str, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get tags for other users.") raise AuthError(403, "Cannot get tags for other users.")
@ -54,12 +63,14 @@ class TagServlet(RestServlet):
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)" "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.handler = hs.get_account_data_handler() self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, tag): async def on_PUT(
self, request: SynapseRequest, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.") raise AuthError(403, "Cannot add tags for other users.")
@ -70,7 +81,9 @@ class TagServlet(RestServlet):
return 200, {} return 200, {}
async def on_DELETE(self, request, user_id, room_id, tag): async def on_DELETE(
self, request: SynapseRequest, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.") raise AuthError(403, "Cannot add tags for other users.")
@ -80,6 +93,6 @@ class TagServlet(RestServlet):
return 200, {} return 200, {}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
TagListServlet(hs).register(http_server) TagListServlet(hs).register(http_server)
TagServlet(hs).register(http_server) TagServlet(hs).register(http_server)

View file

@ -12,27 +12,33 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Dict, List, Tuple
from synapse.api.constants import ThirdPartyEntityKind from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet): class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocols") PATTERNS = client_patterns("/thirdparty/protocols")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
async def on_GET(self, request): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.get_user_by_req(request, allow_guest=True)
protocols = await self.appservice_handler.get_3pe_protocols() protocols = await self.appservice_handler.get_3pe_protocols()
@ -42,13 +48,15 @@ class ThirdPartyProtocolsServlet(RestServlet):
class ThirdPartyProtocolServlet(RestServlet): class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$") PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
async def on_GET(self, request, protocol): async def on_GET(
self, request: SynapseRequest, protocol: str
) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.get_user_by_req(request, allow_guest=True)
protocols = await self.appservice_handler.get_3pe_protocols( protocols = await self.appservice_handler.get_3pe_protocols(
@ -63,16 +71,18 @@ class ThirdPartyProtocolServlet(RestServlet):
class ThirdPartyUserServlet(RestServlet): class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$") PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
async def on_GET(self, request, protocol): async def on_GET(
self, request: SynapseRequest, protocol: str
) -> Tuple[int, List[JsonDict]]:
await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment]
fields.pop(b"access_token", None) fields.pop(b"access_token", None)
results = await self.appservice_handler.query_3pe( results = await self.appservice_handler.query_3pe(
@ -85,16 +95,18 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$") PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
async def on_GET(self, request, protocol): async def on_GET(
self, request: SynapseRequest, protocol: str
) -> Tuple[int, List[JsonDict]]:
await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment]
fields.pop(b"access_token", None) fields.pop(b"access_token", None)
results = await self.appservice_handler.query_3pe( results = await self.appservice_handler.query_3pe(
@ -104,7 +116,7 @@ class ThirdPartyLocationServlet(RestServlet):
return 200, results return 200, results
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ThirdPartyProtocolsServlet(hs).register(http_server) ThirdPartyProtocolsServlet(hs).register(http_server)
ThirdPartyProtocolServlet(hs).register(http_server) ThirdPartyProtocolServlet(hs).register(http_server)
ThirdPartyUserServlet(hs).register(http_server) ThirdPartyUserServlet(hs).register(http_server)

View file

@ -12,11 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
class TokenRefreshRestServlet(RestServlet): class TokenRefreshRestServlet(RestServlet):
""" """
@ -26,12 +34,12 @@ class TokenRefreshRestServlet(RestServlet):
PATTERNS = client_patterns("/tokenrefresh") PATTERNS = client_patterns("/tokenrefresh")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
async def on_POST(self, request): async def on_POST(self, request: Request) -> None:
raise AuthError(403, "tokenrefresh is no longer supported.") raise AuthError(403, "tokenrefresh is no longer supported.")
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
TokenRefreshRestServlet(hs).register(http_server) TokenRefreshRestServlet(hs).register(http_server)

View file

@ -13,29 +13,32 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet): class UserDirectorySearchRestServlet(RestServlet):
PATTERNS = client_patterns("/user_directory/search$") PATTERNS = client_patterns("/user_directory/search$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler() self.user_directory_handler = hs.get_user_directory_handler()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Searches for users in directory """Searches for users in directory
Returns: Returns:
@ -75,5 +78,5 @@ class UserDirectorySearchRestServlet(RestServlet):
return 200, results return 200, results
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserDirectorySearchRestServlet(hs).register(http_server) UserDirectorySearchRestServlet(hs).register(http_server)

View file

@ -17,9 +17,17 @@
import logging import logging
import re import re
from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.api.constants import RoomCreationPreset from synapse.api.constants import RoomCreationPreset
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,7 +35,7 @@ logger = logging.getLogger(__name__)
class VersionsRestServlet(RestServlet): class VersionsRestServlet(RestServlet):
PATTERNS = [re.compile("^/_matrix/client/versions$")] PATTERNS = [re.compile("^/_matrix/client/versions$")]
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.config = hs.config self.config = hs.config
@ -45,7 +53,7 @@ class VersionsRestServlet(RestServlet):
in self.config.encryption_enabled_by_default_for_room_presets in self.config.encryption_enabled_by_default_for_room_presets
) )
def on_GET(self, request): def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
return ( return (
200, 200,
{ {
@ -89,5 +97,5 @@ class VersionsRestServlet(RestServlet):
) )
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
VersionsRestServlet(hs).register(http_server) VersionsRestServlet(hs).register(http_server)

View file

@ -15,20 +15,27 @@
import base64 import base64
import hashlib import hashlib
import hmac import hmac
from typing import TYPE_CHECKING, Tuple
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
class VoipRestServlet(RestServlet): class VoipRestServlet(RestServlet):
PATTERNS = client_patterns("/voip/turnServer$", v1=True) PATTERNS = client_patterns("/voip/turnServer$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req( requester = await self.auth.get_user_by_req(
request, self.hs.config.turn_allow_guests request, self.hs.config.turn_allow_guests
) )
@ -69,5 +76,5 @@ class VoipRestServlet(RestServlet):
) )
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
VoipRestServlet(hs).register(http_server) VoipRestServlet(hs).register(http_server)