Add missing type hints to handlers and fix a Spam Checker type hint. (#9896)

The user_may_create_room_alias method on spam checkers
declared the room_alias parameter as a str when in reality it is
passed a RoomAlias object.
This commit is contained in:
Patrick Cloke 2021-04-29 07:17:28 -04:00 committed by GitHub
parent 0085dc5abc
commit bb4b11846f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 82 additions and 54 deletions

1
changelog.d/9896.bugfix Normal file
View file

@ -0,0 +1 @@
Correct the type hint for the `user_may_create_room_alias` method of spam checkers. It is provided a `RoomAlias`, not a `str`.

1
changelog.d/9896.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to the `synapse.handlers` module.

View file

@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple,
from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import ReadableFileWrapper from synapse.rest.media.v1.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias
from synapse.util.async_helpers import maybe_awaitable from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING: if TYPE_CHECKING:
@ -113,7 +114,9 @@ class SpamChecker:
return True return True
async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool: async def user_may_create_room_alias(
self, userid: str, room_alias: RoomAlias
) -> bool:
"""Checks if a given user may create a room alias """Checks if a given user may create a room alias
If this method returns false, the association request will be rejected. If this method returns false, the association request will be rejected.

View file

@ -14,7 +14,7 @@
import logging import logging
import string import string
from typing import Iterable, List, Optional from typing import TYPE_CHECKING, Iterable, List, Optional
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
from synapse.api.errors import ( from synapse.api.errors import (
@ -27,15 +27,19 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id
from ._base import BaseHandler from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DirectoryHandler(BaseHandler): class DirectoryHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
@ -60,7 +64,7 @@ class DirectoryHandler(BaseHandler):
room_id: str, room_id: str,
servers: Optional[Iterable[str]] = None, servers: Optional[Iterable[str]] = None,
creator: Optional[str] = None, creator: Optional[str] = None,
): ) -> None:
# general association creation for both human users and app services # general association creation for both human users and app services
for wchar in string.whitespace: for wchar in string.whitespace:
@ -104,8 +108,9 @@ class DirectoryHandler(BaseHandler):
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
room_alias_str = room_alias.to_string()
if len(room_alias.to_string()) > MAX_ALIAS_LENGTH: if len(room_alias_str) > MAX_ALIAS_LENGTH:
raise SynapseError( raise SynapseError(
400, 400,
"Can't create aliases longer than %s characters" % MAX_ALIAS_LENGTH, "Can't create aliases longer than %s characters" % MAX_ALIAS_LENGTH,
@ -114,7 +119,7 @@ class DirectoryHandler(BaseHandler):
service = requester.app_service service = requester.app_service
if service: if service:
if not service.is_interested_in_alias(room_alias.to_string()): if not service.is_interested_in_alias(room_alias_str):
raise SynapseError( raise SynapseError(
400, 400,
"This application service has not reserved this kind of alias.", "This application service has not reserved this kind of alias.",
@ -138,7 +143,7 @@ class DirectoryHandler(BaseHandler):
raise AuthError(403, "This user is not permitted to create this alias") raise AuthError(403, "This user is not permitted to create this alias")
if not self.config.is_alias_creation_allowed( if not self.config.is_alias_creation_allowed(
user_id, room_id, room_alias.to_string() user_id, room_id, room_alias_str
): ):
# Lets just return a generic message, as there may be all sorts of # Lets just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages # reasons why we said no. TODO: Allow configurable error messages
@ -211,7 +216,7 @@ class DirectoryHandler(BaseHandler):
async def delete_appservice_association( async def delete_appservice_association(
self, service: ApplicationService, room_alias: RoomAlias self, service: ApplicationService, room_alias: RoomAlias
): ) -> None:
if not service.is_interested_in_alias(room_alias.to_string()): if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError( raise SynapseError(
400, 400,
@ -220,7 +225,7 @@ class DirectoryHandler(BaseHandler):
) )
await self._delete_association(room_alias) await self._delete_association(room_alias)
async def _delete_association(self, room_alias: RoomAlias): async def _delete_association(self, room_alias: RoomAlias) -> str:
if not self.hs.is_mine(room_alias): if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local") raise SynapseError(400, "Room alias must be local")
@ -228,17 +233,19 @@ class DirectoryHandler(BaseHandler):
return room_id return room_id
async def get_association(self, room_alias: RoomAlias): async def get_association(self, room_alias: RoomAlias) -> JsonDict:
room_id = None room_id = None
if self.hs.is_mine(room_alias): if self.hs.is_mine(room_alias):
result = await self.get_association_from_room_alias(room_alias) result = await self.get_association_from_room_alias(
room_alias
) # type: Optional[RoomAliasMapping]
if result: if result:
room_id = result.room_id room_id = result.room_id
servers = result.servers servers = result.servers
else: else:
try: try:
result = await self.federation.make_query( fed_result = await self.federation.make_query(
destination=room_alias.domain, destination=room_alias.domain,
query_type="directory", query_type="directory",
args={"room_alias": room_alias.to_string()}, args={"room_alias": room_alias.to_string()},
@ -248,13 +255,13 @@ class DirectoryHandler(BaseHandler):
except CodeMessageException as e: except CodeMessageException as e:
logging.warning("Error retrieving alias") logging.warning("Error retrieving alias")
if e.code == 404: if e.code == 404:
result = None fed_result = None
else: else:
raise raise
if result and "room_id" in result and "servers" in result: if fed_result and "room_id" in fed_result and "servers" in fed_result:
room_id = result["room_id"] room_id = fed_result["room_id"]
servers = result["servers"] servers = fed_result["servers"]
if not room_id: if not room_id:
raise SynapseError( raise SynapseError(
@ -275,7 +282,7 @@ class DirectoryHandler(BaseHandler):
return {"room_id": room_id, "servers": servers} return {"room_id": room_id, "servers": servers}
async def on_directory_query(self, args): async def on_directory_query(self, args: JsonDict) -> JsonDict:
room_alias = RoomAlias.from_string(args["room_alias"]) room_alias = RoomAlias.from_string(args["room_alias"])
if not self.hs.is_mine(room_alias): if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room Alias is not hosted on this homeserver") raise SynapseError(400, "Room Alias is not hosted on this homeserver")
@ -293,7 +300,7 @@ class DirectoryHandler(BaseHandler):
async def _update_canonical_alias( async def _update_canonical_alias(
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
): ) -> None:
""" """
Send an updated canonical alias event if the removed alias was set as Send an updated canonical alias event if the removed alias was set as
the canonical alias or listed in the alt_aliases field. the canonical alias or listed in the alt_aliases field.
@ -344,7 +351,9 @@ class DirectoryHandler(BaseHandler):
ratelimit=False, ratelimit=False,
) )
async def get_association_from_room_alias(self, room_alias: RoomAlias): async def get_association_from_room_alias(
self, room_alias: RoomAlias
) -> Optional[RoomAliasMapping]:
result = await self.store.get_association_from_room_alias(room_alias) result = await self.store.get_association_from_room_alias(room_alias)
if not result: if not result:
# Query AS to see if it exists # Query AS to see if it exists
@ -372,7 +381,7 @@ class DirectoryHandler(BaseHandler):
# either no interested services, or no service with an exclusive lock # either no interested services, or no service with an exclusive lock
return True return True
async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str) -> bool:
"""Determine whether a user can delete an alias. """Determine whether a user can delete an alias.
One of the following must be true: One of the following must be true:
@ -394,14 +403,13 @@ class DirectoryHandler(BaseHandler):
if not room_id: if not room_id:
return False return False
res = await self.auth.check_can_change_room_list( return await self.auth.check_can_change_room_list(
room_id, UserID.from_string(user_id) room_id, UserID.from_string(user_id)
) )
return res
async def edit_published_room_list( async def edit_published_room_list(
self, requester: Requester, room_id: str, visibility: str self, requester: Requester, room_id: str, visibility: str
): ) -> None:
"""Edit the entry of the room in the published room list. """Edit the entry of the room in the published room list.
requester requester
@ -469,7 +477,7 @@ class DirectoryHandler(BaseHandler):
async def edit_published_appservice_room_list( async def edit_published_appservice_room_list(
self, appservice_id: str, network_id: str, room_id: str, visibility: str self, appservice_id: str, network_id: str, room_id: str, visibility: str
): ) -> None:
"""Add or remove a room from the appservice/network specific public """Add or remove a room from the appservice/network specific public
room list. room list.
@ -499,5 +507,4 @@ class DirectoryHandler(BaseHandler):
room_id, requester.user.to_string() room_id, requester.user.to_string()
) )
aliases = await self.store.get_aliases_for_room(room_id) return await self.store.get_aliases_for_room(room_id)
return aliases

View file

@ -17,7 +17,7 @@
"""Utilities for interacting with Identity Servers""" """Utilities for interacting with Identity Servers"""
import logging import logging
import urllib.parse import urllib.parse
from typing import Awaitable, Callable, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
from synapse.api.errors import ( from synapse.api.errors import (
CodeMessageException, CodeMessageException,
@ -41,13 +41,16 @@ from synapse.util.stringutils import (
from ._base import BaseHandler from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
id_server_scheme = "https://" id_server_scheme = "https://"
class IdentityHandler(BaseHandler): class IdentityHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
# An HTTP client for contacting trusted URLs. # An HTTP client for contacting trusted URLs.
@ -80,7 +83,7 @@ class IdentityHandler(BaseHandler):
request: SynapseRequest, request: SynapseRequest,
medium: str, medium: str,
address: str, address: str,
): ) -> None:
"""Used to ratelimit requests to `/requestToken` by IP and address. """Used to ratelimit requests to `/requestToken` by IP and address.
Args: Args:

View file

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import random import random
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
@ -66,7 +66,7 @@ logger = logging.getLogger(__name__)
class MessageHandler: class MessageHandler:
"""Contains some read only APIs to get state about a room""" """Contains some read only APIs to get state about a room"""
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
@ -91,7 +91,7 @@ class MessageHandler:
room_id: str, room_id: str,
event_type: str, event_type: str,
state_key: str, state_key: str,
) -> dict: ) -> Optional[EventBase]:
"""Get data from a room. """Get data from a room.
Args: Args:
@ -115,6 +115,10 @@ class MessageHandler:
data = await self.state.get_current_state(room_id, event_type, state_key) data = await self.state.get_current_state(room_id, event_type, state_key)
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
key = (event_type, state_key) key = (event_type, state_key)
# If the membership is not JOIN, then the event ID should exist.
assert (
membership_event_id is not None
), "check_user_in_room_or_world_readable returned invalid data"
room_state = await self.state_store.get_state_for_events( room_state = await self.state_store.get_state_for_events(
[membership_event_id], StateFilter.from_types([key]) [membership_event_id], StateFilter.from_types([key])
) )
@ -186,10 +190,12 @@ class MessageHandler:
event = last_events[0] event = last_events[0]
if visible_events: if visible_events:
room_state = await self.state_store.get_state_for_events( room_state_events = await self.state_store.get_state_for_events(
[event.event_id], state_filter=state_filter [event.event_id], state_filter=state_filter
) )
room_state = room_state[event.event_id] room_state = room_state_events[
event.event_id
] # type: Mapping[Any, EventBase]
else: else:
raise AuthError( raise AuthError(
403, 403,
@ -210,10 +216,14 @@ class MessageHandler:
) )
room_state = await self.store.get_events(state_ids.values()) room_state = await self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
room_state = await self.state_store.get_state_for_events( # If the membership is not JOIN, then the event ID should exist.
assert (
membership_event_id is not None
), "check_user_in_room_or_world_readable returned invalid data"
room_state_events = await self.state_store.get_state_for_events(
[membership_event_id], state_filter=state_filter [membership_event_id], state_filter=state_filter
) )
room_state = room_state[membership_event_id] room_state = room_state_events[membership_event_id]
now = self.clock.time_msec() now = self.clock.time_msec()
events = await self._event_serializer.serialize_events( events = await self._event_serializer.serialize_events(

View file

@ -1044,7 +1044,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
class RoomMemberMasterHandler(RoomMemberHandler): class RoomMemberMasterHandler(RoomMemberHandler):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any from typing import TYPE_CHECKING, Any
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@ -22,13 +22,16 @@ from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.emailconfig import ThreepidBehaviour
from synapse.util import json_decoder from synapse.util import json_decoder
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UserInteractiveAuthChecker: class UserInteractiveAuthChecker:
"""Abstract base class for an interactive auth checker""" """Abstract base class for an interactive auth checker"""
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
pass pass
def is_enabled(self) -> bool: def is_enabled(self) -> bool:
@ -57,10 +60,10 @@ class UserInteractiveAuthChecker:
class DummyAuthChecker(UserInteractiveAuthChecker): class DummyAuthChecker(UserInteractiveAuthChecker):
AUTH_TYPE = LoginType.DUMMY AUTH_TYPE = LoginType.DUMMY
def is_enabled(self): def is_enabled(self) -> bool:
return True return True
async def check_auth(self, authdict, clientip): async def check_auth(self, authdict: dict, clientip: str) -> Any:
return True return True
@ -70,24 +73,24 @@ class TermsAuthChecker(UserInteractiveAuthChecker):
def is_enabled(self): def is_enabled(self):
return True return True
async def check_auth(self, authdict, clientip): async def check_auth(self, authdict: dict, clientip: str) -> Any:
return True return True
class RecaptchaAuthChecker(UserInteractiveAuthChecker): class RecaptchaAuthChecker(UserInteractiveAuthChecker):
AUTH_TYPE = LoginType.RECAPTCHA AUTH_TYPE = LoginType.RECAPTCHA
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self._enabled = bool(hs.config.recaptcha_private_key) self._enabled = bool(hs.config.recaptcha_private_key)
self._http_client = hs.get_proxied_http_client() self._http_client = hs.get_proxied_http_client()
self._url = hs.config.recaptcha_siteverify_api self._url = hs.config.recaptcha_siteverify_api
self._secret = hs.config.recaptcha_private_key self._secret = hs.config.recaptcha_private_key
def is_enabled(self): def is_enabled(self) -> bool:
return self._enabled return self._enabled
async def check_auth(self, authdict, clientip): async def check_auth(self, authdict: dict, clientip: str) -> Any:
try: try:
user_response = authdict["response"] user_response = authdict["response"]
except KeyError: except KeyError:
@ -132,11 +135,11 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
class _BaseThreepidAuthChecker: class _BaseThreepidAuthChecker:
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def _check_threepid(self, medium, authdict): async def _check_threepid(self, medium: str, authdict: dict) -> dict:
if "threepid_creds" not in authdict: if "threepid_creds" not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
@ -206,31 +209,31 @@ class _BaseThreepidAuthChecker:
class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
AUTH_TYPE = LoginType.EMAIL_IDENTITY AUTH_TYPE = LoginType.EMAIL_IDENTITY
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
UserInteractiveAuthChecker.__init__(self, hs) UserInteractiveAuthChecker.__init__(self, hs)
_BaseThreepidAuthChecker.__init__(self, hs) _BaseThreepidAuthChecker.__init__(self, hs)
def is_enabled(self): def is_enabled(self) -> bool:
return self.hs.config.threepid_behaviour_email in ( return self.hs.config.threepid_behaviour_email in (
ThreepidBehaviour.REMOTE, ThreepidBehaviour.REMOTE,
ThreepidBehaviour.LOCAL, ThreepidBehaviour.LOCAL,
) )
async def check_auth(self, authdict, clientip): async def check_auth(self, authdict: dict, clientip: str) -> Any:
return await self._check_threepid("email", authdict) return await self._check_threepid("email", authdict)
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
AUTH_TYPE = LoginType.MSISDN AUTH_TYPE = LoginType.MSISDN
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
UserInteractiveAuthChecker.__init__(self, hs) UserInteractiveAuthChecker.__init__(self, hs)
_BaseThreepidAuthChecker.__init__(self, hs) _BaseThreepidAuthChecker.__init__(self, hs)
def is_enabled(self): def is_enabled(self) -> bool:
return bool(self.hs.config.account_threepid_delegate_msisdn) return bool(self.hs.config.account_threepid_delegate_msisdn)
async def check_auth(self, authdict, clientip): async def check_auth(self, authdict: dict, clientip: str) -> Any:
return await self._check_threepid("msisdn", authdict) return await self._check_threepid("msisdn", authdict)