Additional type hints for REST servlets (part 2). (#10674)

Applies the changes from #10665 to additional modules.
This commit is contained in:
Patrick Cloke 2021-08-26 07:53:52 -04:00 committed by GitHub
parent 5548fe0978
commit 1aa0dad021
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 216 additions and 138 deletions

1
changelog.d/10674.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to REST servlets.

View file

@ -353,6 +353,11 @@ class BasePresenceHandler(abc.ABC):
# otherwise would not do). # otherwise would not do).
await self.set_state(UserID.from_string(user_id), state, force_notify=True) await self.set_state(UserID.from_string(user_id), state, force_notify=True)
async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool:
raise NotImplementedError(
"Attempting to check presence on a non-presence worker."
)
class _NullContextManager(ContextManager[None]): class _NullContextManager(ContextManager[None]):
"""A context manager which does nothing.""" """A context manager which does nothing."""

View file

@ -15,11 +15,14 @@
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import LoginError, SynapseError from synapse.api.errors import LoginError, SynapseError
from synapse.api.urls import CLIENT_API_PREFIX from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.server import respond_with_html from synapse.http.server import HttpServer, respond_with_html
from synapse.http.servlet import RestServlet, parse_string from synapse.http.servlet import RestServlet, parse_string
from synapse.http.site import SynapseRequest
from ._base import client_patterns from ._base import client_patterns
@ -49,7 +52,7 @@ class AuthRestServlet(RestServlet):
self.registration_token_template = hs.config.registration_token_template self.registration_token_template = hs.config.registration_token_template
self.success_template = hs.config.fallback_success_template self.success_template = hs.config.fallback_success_template
async def on_GET(self, request, stagetype): async def on_GET(self, request: SynapseRequest, stagetype: str) -> None:
session = parse_string(request, "session") session = parse_string(request, "session")
if not session: if not session:
raise SynapseError(400, "No session supplied") raise SynapseError(400, "No session supplied")
@ -88,7 +91,7 @@ class AuthRestServlet(RestServlet):
respond_with_html(request, 200, html) respond_with_html(request, 200, html)
return None return None
async def on_POST(self, request, stagetype): async def on_POST(self, request: Request, stagetype: str) -> None:
session = parse_string(request, "session") session = parse_string(request, "session")
if not session: if not session:
@ -172,5 +175,5 @@ class AuthRestServlet(RestServlet):
return None return None
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AuthRestServlet(hs).register(http_server) AuthRestServlet(hs).register(http_server)

View file

@ -14,34 +14,36 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api import errors from synapse.api import errors
from synapse.http.server import HttpServer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, assert_params_in_dict,
parse_json_object_from_request, parse_json_object_from_request,
) )
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns, interactive_auth_handler from ._base import client_patterns, interactive_auth_handler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet): class DevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/devices$") PATTERNS = client_patterns("/devices$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
async def on_GET(self, request): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
devices = await self.device_handler.get_devices_by_user( devices = await self.device_handler.get_devices_by_user(
requester.user.to_string() requester.user.to_string()
@ -57,7 +59,7 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/delete_devices") PATTERNS = client_patterns("/delete_devices")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -65,7 +67,7 @@ class DeleteDevicesRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler @interactive_auth_handler
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
try: try:
@ -100,18 +102,16 @@ class DeleteDevicesRestServlet(RestServlet):
class DeviceRestServlet(RestServlet): class DeviceRestServlet(RestServlet):
PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$") PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
async def on_GET(self, request, device_id): async def on_GET(
self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
device = await self.device_handler.get_device( device = await self.device_handler.get_device(
requester.user.to_string(), device_id requester.user.to_string(), device_id
@ -119,7 +119,9 @@ class DeviceRestServlet(RestServlet):
return 200, device return 200, device
@interactive_auth_handler @interactive_auth_handler
async def on_DELETE(self, request, device_id): async def on_DELETE(
self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
try: try:
@ -146,7 +148,9 @@ class DeviceRestServlet(RestServlet):
await self.device_handler.delete_device(requester.user.to_string(), device_id) await self.device_handler.delete_device(requester.user.to_string(), device_id)
return 200, {} return 200, {}
async def on_PUT(self, request, device_id): async def on_PUT(
self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -193,13 +197,13 @@ class DehydratedDeviceServlet(RestServlet):
PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=()) PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=())
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
async def on_GET(self, request: SynapseRequest): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
dehydrated_device = await self.device_handler.get_dehydrated_device( dehydrated_device = await self.device_handler.get_dehydrated_device(
requester.user.to_string() requester.user.to_string()
@ -211,7 +215,7 @@ class DehydratedDeviceServlet(RestServlet):
else: else:
raise errors.NotFoundError("No dehydrated device available") raise errors.NotFoundError("No dehydrated device available")
async def on_PUT(self, request: SynapseRequest): async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
submission = parse_json_object_from_request(request) submission = parse_json_object_from_request(request)
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
@ -259,13 +263,13 @@ class ClaimDehydratedDeviceServlet(RestServlet):
"/org.matrix.msc2697.v2/dehydrated_device/claim", releases=() "/org.matrix.msc2697.v2/dehydrated_device/claim", releases=()
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
async def on_POST(self, request: SynapseRequest): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
submission = parse_json_object_from_request(request) submission = parse_json_object_from_request(request)
@ -292,7 +296,7 @@ class ClaimDehydratedDeviceServlet(RestServlet):
return (200, result) return (200, result)
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
DeleteDevicesRestServlet(hs).register(http_server) DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server) DeviceRestServlet(hs).register(http_server)

View file

@ -14,11 +14,18 @@
"""This module contains REST servlets to do with event streaming, /events.""" """This module contains REST servlets to do with event streaming, /events."""
import logging import logging
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns from synapse.rest.client._base import client_patterns
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,31 +35,30 @@ class EventStreamRestServlet(RestServlet):
DEFAULT_LONGPOLL_TIME_MS = 30000 DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.event_stream_handler = hs.get_event_stream_handler() self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
is_guest = requester.is_guest is_guest = requester.is_guest
room_id = None args: Dict[bytes, List[bytes]] = request.args # type: ignore
if is_guest: if is_guest:
if b"room_id" not in request.args: if b"room_id" not in args:
raise SynapseError(400, "Guest users must specify room_id param") raise SynapseError(400, "Guest users must specify room_id param")
if b"room_id" in request.args: room_id = parse_string(request, "room_id")
room_id = request.args[b"room_id"][0].decode("ascii")
pagin_config = await PaginationConfig.from_request(self.store, request) pagin_config = await PaginationConfig.from_request(self.store, request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if b"timeout" in request.args: if b"timeout" in args:
try: try:
timeout = int(request.args[b"timeout"][0]) timeout = int(args[b"timeout"][0])
except ValueError: except ValueError:
raise SynapseError(400, "timeout must be in milliseconds.") raise SynapseError(400, "timeout must be in milliseconds.")
as_client_event = b"raw" not in request.args as_client_event = b"raw" not in args
chunk = await self.event_stream_handler.get_stream( chunk = await self.event_stream_handler.get_stream(
requester.user.to_string(), requester.user.to_string(),
@ -70,25 +76,27 @@ class EventStreamRestServlet(RestServlet):
class EventRestServlet(RestServlet): class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True) PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
async def on_GET(self, request, event_id): async def on_GET(
self, request: SynapseRequest, event_id: str
) -> Tuple[int, Union[str, JsonDict]]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
event = await self.event_handler.get_event(requester.user, None, event_id) event = await self.event_handler.get_event(requester.user, None, event_id)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
if event: if event:
event = await self._event_serializer.serialize_event(event, time_now) result = await self._event_serializer.serialize_event(event, time_now)
return 200, event return 200, result
else: else:
return 404, "Event not found." return 404, "Event not found."
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EventStreamRestServlet(hs).register(http_server) EventStreamRestServlet(hs).register(http_server)
EventRestServlet(hs).register(http_server) EventRestServlet(hs).register(http_server)

View file

@ -13,26 +13,34 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID
from ._base import client_patterns, set_timeline_upper_limit from ._base import client_patterns, set_timeline_upper_limit
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet): class GetFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
async def on_GET(self, request, user_id, filter_id): async def on_GET(
self, request: SynapseRequest, user_id: str, filter_id: str
) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
@ -43,13 +51,13 @@ class GetFilterRestServlet(RestServlet):
raise AuthError(403, "Can only get filters for local users") raise AuthError(403, "Can only get filters for local users")
try: try:
filter_id = int(filter_id) filter_id_int = int(filter_id)
except Exception: except Exception:
raise SynapseError(400, "Invalid filter_id") raise SynapseError(400, "Invalid filter_id")
try: try:
filter_collection = await self.filtering.get_user_filter( filter_collection = await self.filtering.get_user_filter(
user_localpart=target_user.localpart, filter_id=filter_id user_localpart=target_user.localpart, filter_id=filter_id_int
) )
except StoreError as e: except StoreError as e:
if e.code != 404: if e.code != 404:
@ -62,13 +70,15 @@ class GetFilterRestServlet(RestServlet):
class CreateFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter") PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
async def on_POST(self, request, user_id): async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
@ -89,6 +99,6 @@ class CreateFilterRestServlet(RestServlet):
return 200, {"filter_id": str(filter_id)} return 200, {"filter_id": str(filter_id)}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
GetFilterRestServlet(hs).register(http_server) GetFilterRestServlet(hs).register(http_server)
CreateFilterRestServlet(hs).register(http_server) CreateFilterRestServlet(hs).register(http_server)

View file

@ -26,6 +26,7 @@ from synapse.api.constants import (
) )
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.handlers.groups_local import GroupsLocalHandler from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.http.server import HttpServer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, assert_params_in_dict,
@ -930,7 +931,7 @@ class GroupsForUserServlet(RestServlet):
return 200, result return 200, result
def register_servlets(hs: "HomeServer", http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
GroupServlet(hs).register(http_server) GroupServlet(hs).register(http_server)
GroupSummaryServlet(hs).register(http_server) GroupSummaryServlet(hs).register(http_server)
GroupInvitedUsersServlet(hs).register(http_server) GroupInvitedUsersServlet(hs).register(http_server)

View file

@ -12,25 +12,33 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Dict, List, Tuple
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean from synapse.http.servlet import RestServlet, parse_boolean
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns from synapse.rest.client._base import client_patterns
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
# TODO: Needs unit testing # TODO: Needs unit testing
class InitialSyncRestServlet(RestServlet): class InitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/initialSync$", v1=True) PATTERNS = client_patterns("/initialSync$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler() self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
as_client_event = b"raw" not in request.args args: Dict[bytes, List[bytes]] = request.args # type: ignore
as_client_event = b"raw" not in args
pagination_config = await PaginationConfig.from_request(self.store, request) pagination_config = await PaginationConfig.from_request(self.store, request)
include_archived = parse_boolean(request, "archived", default=False) include_archived = parse_boolean(request, "archived", default=False)
content = await self.initial_sync_handler.snapshot_all_rooms( content = await self.initial_sync_handler.snapshot_all_rooms(
@ -43,5 +51,5 @@ class InitialSyncRestServlet(RestServlet):
return 200, content return 200, content
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
InitialSyncRestServlet(hs).register(http_server) InitialSyncRestServlet(hs).register(http_server)

View file

@ -15,20 +15,25 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any from typing import TYPE_CHECKING, Any, Optional, Tuple
from synapse.api.errors import InvalidAPICallError, SynapseError from synapse.api.errors import InvalidAPICallError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
parse_integer, parse_integer,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.types import StreamToken from synapse.types import JsonDict, StreamToken
from ._base import client_patterns, interactive_auth_handler from ._base import client_patterns, interactive_auth_handler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -60,18 +65,16 @@ class KeyUploadServlet(RestServlet):
PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
@trace(opname="upload_keys") @trace(opname="upload_keys")
async def on_POST(self, request, device_id): async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -149,16 +152,12 @@ class KeyQueryServlet(RestServlet):
PATTERNS = client_patterns("/keys/query$") PATTERNS = client_patterns("/keys/query$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer):
"""
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
device_id = requester.device_id device_id = requester.device_id
@ -195,17 +194,13 @@ class KeyChangesServlet(RestServlet):
PATTERNS = client_patterns("/keys/changes$") PATTERNS = client_patterns("/keys/changes$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer):
"""
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
from_token_string = parse_string(request, "from", required=True) from_token_string = parse_string(request, "from", required=True)
@ -245,12 +240,12 @@ class OneTimeKeyServlet(RestServlet):
PATTERNS = client_patterns("/keys/claim$") PATTERNS = client_patterns("/keys/claim$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -269,11 +264,7 @@ class SigningKeyUploadServlet(RestServlet):
PATTERNS = client_patterns("/keys/device_signing/upload$", releases=()) PATTERNS = client_patterns("/keys/device_signing/upload$", releases=())
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -281,7 +272,7 @@ class SigningKeyUploadServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler @interactive_auth_handler
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -329,16 +320,12 @@ class SignaturesUploadServlet(RestServlet):
PATTERNS = client_patterns("/keys/signatures/upload$") PATTERNS = client_patterns("/keys/signatures/upload$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -349,7 +336,7 @@ class SignaturesUploadServlet(RestServlet):
return 200, result return 200, result
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KeyUploadServlet(hs).register(http_server) KeyUploadServlet(hs).register(http_server)
KeyQueryServlet(hs).register(http_server) KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server) KeyChangesServlet(hs).register(http_server)

View file

@ -19,6 +19,7 @@ from twisted.web.server import Request
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
parse_json_object_from_request, parse_json_object_from_request,
@ -103,5 +104,5 @@ class KnockRoomAliasServlet(RestServlet):
) )
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KnockRoomAliasServlet(hs).register(http_server) KnockRoomAliasServlet(hs).register(http_server)

View file

@ -1,4 +1,4 @@
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,7 +14,7 @@
import logging import logging
import re import re
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
from typing_extensions import TypedDict from typing_extensions import TypedDict
@ -110,7 +110,7 @@ class LoginRestServlet(RestServlet):
# counters are initialised for the auth_provider_ids. # counters are initialised for the auth_provider_ids.
_load_sso_handlers(hs) _load_sso_handlers(hs)
def on_GET(self, request: SynapseRequest): def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
flows = [] flows = []
if self.jwt_enabled: if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE}) flows.append({"type": LoginRestServlet.JWT_TYPE})
@ -157,7 +157,7 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows} return 200, {"flows": flows}
async def on_POST(self, request: SynapseRequest): async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
login_submission = parse_json_object_from_request(request) login_submission = parse_json_object_from_request(request)
if self._msc2918_enabled: if self._msc2918_enabled:
@ -217,7 +217,7 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict, login_submission: JsonDict,
appservice: ApplicationService, appservice: ApplicationService,
should_issue_refresh_token: bool = False, should_issue_refresh_token: bool = False,
): ) -> LoginResponse:
identifier = login_submission.get("identifier") identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier) logger.info("Got appservice login request with identifier: %r", identifier)
@ -467,10 +467,7 @@ class RefreshTokenServlet(RestServlet):
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.access_token_lifetime = hs.config.access_token_lifetime self.access_token_lifetime = hs.config.access_token_lifetime
async def on_POST( async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
self,
request: SynapseRequest,
):
refresh_submission = parse_json_object_from_request(request) refresh_submission = parse_json_object_from_request(request)
assert_params_in_dict(refresh_submission, ["refresh_token"]) assert_params_in_dict(refresh_submission, ["refresh_token"])
@ -570,7 +567,7 @@ class SsoRedirectServlet(RestServlet):
class CasTicketServlet(RestServlet): class CasTicketServlet(RestServlet):
PATTERNS = client_patterns("/login/cas/ticket", v1=True) PATTERNS = client_patterns("/login/cas/ticket", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self._cas_handler = hs.get_cas_handler() self._cas_handler = hs.get_cas_handler()
@ -592,7 +589,7 @@ class CasTicketServlet(RestServlet):
) )
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
LoginRestServlet(hs).register(http_server) LoginRestServlet(hs).register(http_server)
if hs.config.access_token_lifetime is not None: if hs.config.access_token_lifetime is not None:
RefreshTokenServlet(hs).register(http_server) RefreshTokenServlet(hs).register(http_server)
@ -601,7 +598,7 @@ def register_servlets(hs, http_server):
CasTicketServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server)
def _load_sso_handlers(hs: "HomeServer"): def _load_sso_handlers(hs: "HomeServer") -> None:
"""Ensure that the SSO handlers are loaded, if they are enabled by configuration. """Ensure that the SSO handlers are loaded, if they are enabled by configuration.
This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves

View file

@ -13,9 +13,16 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,13 +30,13 @@ logger = logging.getLogger(__name__)
class LogoutRestServlet(RestServlet): class LogoutRestServlet(RestServlet):
PATTERNS = client_patterns("/logout$", v1=True) PATTERNS = client_patterns("/logout$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler() self._device_handler = hs.get_device_handler()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True) requester = await self.auth.get_user_by_req(request, allow_expired=True)
if requester.device_id is None: if requester.device_id is None:
@ -48,13 +55,13 @@ class LogoutRestServlet(RestServlet):
class LogoutAllRestServlet(RestServlet): class LogoutAllRestServlet(RestServlet):
PATTERNS = client_patterns("/logout/all$", v1=True) PATTERNS = client_patterns("/logout/all$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler() self._device_handler = hs.get_device_handler()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True) requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
@ -67,6 +74,6 @@ class LogoutAllRestServlet(RestServlet):
return 200, {} return 200, {}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
LogoutRestServlet(hs).register(http_server) LogoutRestServlet(hs).register(http_server)
LogoutAllRestServlet(hs).register(http_server) LogoutAllRestServlet(hs).register(http_server)

View file

@ -13,26 +13,33 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet): class NotificationsServlet(RestServlet):
PATTERNS = client_patterns("/notifications$") PATTERNS = client_patterns("/notifications$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
async def on_GET(self, request): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
@ -87,5 +94,5 @@ class NotificationsServlet(RestServlet):
return 200, {"notifications": returned_push_actions, "next_token": next_token} return 200, {"notifications": returned_push_actions, "next_token": next_token}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
NotificationsServlet(hs).register(http_server) NotificationsServlet(hs).register(http_server)

View file

@ -12,15 +12,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -58,14 +64,16 @@ class IdTokenServlet(RestServlet):
EXPIRES_MS = 3600 * 1000 EXPIRES_MS = 3600 * 1000
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.server_name = hs.config.server_name self.server_name = hs.config.server_name
async def on_POST(self, request, user_id): async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot request tokens for other users.") raise AuthError(403, "Cannot request tokens for other users.")
@ -90,5 +98,5 @@ class IdTokenServlet(RestServlet):
) )
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
IdTokenServlet(hs).register(http_server) IdTokenServlet(hs).register(http_server)

View file

@ -13,28 +13,32 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.types import JsonDict
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PasswordPolicyServlet(RestServlet): class PasswordPolicyServlet(RestServlet):
PATTERNS = client_patterns("/password_policy$") PATTERNS = client_patterns("/password_policy$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.policy = hs.config.password_policy self.policy = hs.config.password_policy
self.enabled = hs.config.password_policy_enabled self.enabled = hs.config.password_policy_enabled
def on_GET(self, request): def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if not self.enabled or not self.policy: if not self.enabled or not self.policy:
return (200, {}) return (200, {})
@ -53,5 +57,5 @@ class PasswordPolicyServlet(RestServlet):
return (200, policy) return (200, policy)
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PasswordPolicyServlet(hs).register(http_server) PasswordPolicyServlet(hs).register(http_server)

View file

@ -15,12 +15,18 @@
""" This module contains REST servlets to do with presence: /presence/<paths> """ This module contains REST servlets to do with presence: /presence/<paths>
""" """
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns from synapse.rest.client._base import client_patterns
from synapse.types import UserID from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,7 +34,7 @@ logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(RestServlet): class PresenceStatusRestServlet(RestServlet):
PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True) PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
@ -37,7 +43,9 @@ class PresenceStatusRestServlet(RestServlet):
self._use_presence = hs.config.server.use_presence self._use_presence = hs.config.server.use_presence
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
@ -53,13 +61,15 @@ class PresenceStatusRestServlet(RestServlet):
raise AuthError(403, "You are not allowed to see their presence.") raise AuthError(403, "You are not allowed to see their presence.")
state = await self.presence_handler.get_state(target_user=user) state = await self.presence_handler.get_state(target_user=user)
state = format_user_presence_state( result = format_user_presence_state(
state, self.clock.time_msec(), include_user_id=False state, self.clock.time_msec(), include_user_id=False
) )
return 200, state return 200, result
async def on_PUT(self, request, user_id): async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
@ -91,5 +101,5 @@ class PresenceStatusRestServlet(RestServlet):
return 200, {} return 200, {}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PresenceStatusRestServlet(hs).register(http_server) PresenceStatusRestServlet(hs).register(http_server)

View file

@ -14,22 +14,31 @@
""" This module contains REST servlets to do with profile: /profile/<paths> """ """ This module contains REST servlets to do with profile: /profile/<paths> """
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns from synapse.rest.client._base import client_patterns
from synapse.types import UserID from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
class ProfileDisplaynameRestServlet(RestServlet): class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True) PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester_user = None requester_user = None
if self.hs.config.require_auth_for_profile_requests: if self.hs.config.require_auth_for_profile_requests:
@ -48,7 +57,9 @@ class ProfileDisplaynameRestServlet(RestServlet):
return 200, ret return 200, ret
async def on_PUT(self, request, user_id): async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester.user) is_admin = await self.auth.is_server_admin(requester.user)
@ -72,13 +83,15 @@ class ProfileDisplaynameRestServlet(RestServlet):
class ProfileAvatarURLRestServlet(RestServlet): class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True) PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester_user = None requester_user = None
if self.hs.config.require_auth_for_profile_requests: if self.hs.config.require_auth_for_profile_requests:
@ -97,7 +110,9 @@ class ProfileAvatarURLRestServlet(RestServlet):
return 200, ret return 200, ret
async def on_PUT(self, request, user_id): async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester.user) is_admin = await self.auth.is_server_admin(requester.user)
@ -120,13 +135,15 @@ class ProfileAvatarURLRestServlet(RestServlet):
class ProfileRestServlet(RestServlet): class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True) PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester_user = None requester_user = None
if self.hs.config.require_auth_for_profile_requests: if self.hs.config.require_auth_for_profile_requests:
@ -149,7 +166,7 @@ class ProfileRestServlet(RestServlet):
return 200, ret return 200, ret
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ProfileDisplaynameRestServlet(hs).register(http_server) ProfileDisplaynameRestServlet(hs).register(http_server)
ProfileAvatarURLRestServlet(hs).register(http_server) ProfileAvatarURLRestServlet(hs).register(http_server)
ProfileRestServlet(hs).register(http_server) ProfileRestServlet(hs).register(http_server)