forked from MirrorHub/synapse
Add missing type hints to handlers and fix a Spam Checker type hint. (#9896)
The user_may_create_room_alias method on spam checkers declared the room_alias parameter as a str when in reality it is passed a RoomAlias object.
This commit is contained in:
parent
0085dc5abc
commit
bb4b11846f
8 changed files with 82 additions and 54 deletions
1
changelog.d/9896.bugfix
Normal file
1
changelog.d/9896.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Correct the type hint for the `user_may_create_room_alias` method of spam checkers. It is provided a `RoomAlias`, not a `str`.
|
1
changelog.d/9896.misc
Normal file
1
changelog.d/9896.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add type hints to the `synapse.handlers` module.
|
|
@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple,
|
||||||
from synapse.rest.media.v1._base import FileInfo
|
from synapse.rest.media.v1._base import FileInfo
|
||||||
from synapse.rest.media.v1.media_storage import ReadableFileWrapper
|
from synapse.rest.media.v1.media_storage import ReadableFileWrapper
|
||||||
from synapse.spam_checker_api import RegistrationBehaviour
|
from synapse.spam_checker_api import RegistrationBehaviour
|
||||||
|
from synapse.types import RoomAlias
|
||||||
from synapse.util.async_helpers import maybe_awaitable
|
from synapse.util.async_helpers import maybe_awaitable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -113,7 +114,9 @@ class SpamChecker:
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
|
async def user_may_create_room_alias(
|
||||||
|
self, userid: str, room_alias: RoomAlias
|
||||||
|
) -> bool:
|
||||||
"""Checks if a given user may create a room alias
|
"""Checks if a given user may create a room alias
|
||||||
|
|
||||||
If this method returns false, the association request will be rejected.
|
If this method returns false, the association request will be rejected.
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import string
|
import string
|
||||||
from typing import Iterable, List, Optional
|
from typing import TYPE_CHECKING, Iterable, List, Optional
|
||||||
|
|
||||||
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
|
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
|
@ -27,15 +27,19 @@ from synapse.api.errors import (
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id
|
from synapse.storage.databases.main.directory import RoomAliasMapping
|
||||||
|
from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DirectoryHandler(BaseHandler):
|
class DirectoryHandler(BaseHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
@ -60,7 +64,7 @@ class DirectoryHandler(BaseHandler):
|
||||||
room_id: str,
|
room_id: str,
|
||||||
servers: Optional[Iterable[str]] = None,
|
servers: Optional[Iterable[str]] = None,
|
||||||
creator: Optional[str] = None,
|
creator: Optional[str] = None,
|
||||||
):
|
) -> None:
|
||||||
# general association creation for both human users and app services
|
# general association creation for both human users and app services
|
||||||
|
|
||||||
for wchar in string.whitespace:
|
for wchar in string.whitespace:
|
||||||
|
@ -104,8 +108,9 @@ class DirectoryHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
room_alias_str = room_alias.to_string()
|
||||||
|
|
||||||
if len(room_alias.to_string()) > MAX_ALIAS_LENGTH:
|
if len(room_alias_str) > MAX_ALIAS_LENGTH:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
"Can't create aliases longer than %s characters" % MAX_ALIAS_LENGTH,
|
"Can't create aliases longer than %s characters" % MAX_ALIAS_LENGTH,
|
||||||
|
@ -114,7 +119,7 @@ class DirectoryHandler(BaseHandler):
|
||||||
|
|
||||||
service = requester.app_service
|
service = requester.app_service
|
||||||
if service:
|
if service:
|
||||||
if not service.is_interested_in_alias(room_alias.to_string()):
|
if not service.is_interested_in_alias(room_alias_str):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
"This application service has not reserved this kind of alias.",
|
"This application service has not reserved this kind of alias.",
|
||||||
|
@ -138,7 +143,7 @@ class DirectoryHandler(BaseHandler):
|
||||||
raise AuthError(403, "This user is not permitted to create this alias")
|
raise AuthError(403, "This user is not permitted to create this alias")
|
||||||
|
|
||||||
if not self.config.is_alias_creation_allowed(
|
if not self.config.is_alias_creation_allowed(
|
||||||
user_id, room_id, room_alias.to_string()
|
user_id, room_id, room_alias_str
|
||||||
):
|
):
|
||||||
# Lets just return a generic message, as there may be all sorts of
|
# Lets just return a generic message, as there may be all sorts of
|
||||||
# reasons why we said no. TODO: Allow configurable error messages
|
# reasons why we said no. TODO: Allow configurable error messages
|
||||||
|
@ -211,7 +216,7 @@ class DirectoryHandler(BaseHandler):
|
||||||
|
|
||||||
async def delete_appservice_association(
|
async def delete_appservice_association(
|
||||||
self, service: ApplicationService, room_alias: RoomAlias
|
self, service: ApplicationService, room_alias: RoomAlias
|
||||||
):
|
) -> None:
|
||||||
if not service.is_interested_in_alias(room_alias.to_string()):
|
if not service.is_interested_in_alias(room_alias.to_string()):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
|
@ -220,7 +225,7 @@ class DirectoryHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
await self._delete_association(room_alias)
|
await self._delete_association(room_alias)
|
||||||
|
|
||||||
async def _delete_association(self, room_alias: RoomAlias):
|
async def _delete_association(self, room_alias: RoomAlias) -> str:
|
||||||
if not self.hs.is_mine(room_alias):
|
if not self.hs.is_mine(room_alias):
|
||||||
raise SynapseError(400, "Room alias must be local")
|
raise SynapseError(400, "Room alias must be local")
|
||||||
|
|
||||||
|
@ -228,17 +233,19 @@ class DirectoryHandler(BaseHandler):
|
||||||
|
|
||||||
return room_id
|
return room_id
|
||||||
|
|
||||||
async def get_association(self, room_alias: RoomAlias):
|
async def get_association(self, room_alias: RoomAlias) -> JsonDict:
|
||||||
room_id = None
|
room_id = None
|
||||||
if self.hs.is_mine(room_alias):
|
if self.hs.is_mine(room_alias):
|
||||||
result = await self.get_association_from_room_alias(room_alias)
|
result = await self.get_association_from_room_alias(
|
||||||
|
room_alias
|
||||||
|
) # type: Optional[RoomAliasMapping]
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
room_id = result.room_id
|
room_id = result.room_id
|
||||||
servers = result.servers
|
servers = result.servers
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
result = await self.federation.make_query(
|
fed_result = await self.federation.make_query(
|
||||||
destination=room_alias.domain,
|
destination=room_alias.domain,
|
||||||
query_type="directory",
|
query_type="directory",
|
||||||
args={"room_alias": room_alias.to_string()},
|
args={"room_alias": room_alias.to_string()},
|
||||||
|
@ -248,13 +255,13 @@ class DirectoryHandler(BaseHandler):
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
logging.warning("Error retrieving alias")
|
logging.warning("Error retrieving alias")
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
result = None
|
fed_result = None
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if result and "room_id" in result and "servers" in result:
|
if fed_result and "room_id" in fed_result and "servers" in fed_result:
|
||||||
room_id = result["room_id"]
|
room_id = fed_result["room_id"]
|
||||||
servers = result["servers"]
|
servers = fed_result["servers"]
|
||||||
|
|
||||||
if not room_id:
|
if not room_id:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
|
@ -275,7 +282,7 @@ class DirectoryHandler(BaseHandler):
|
||||||
|
|
||||||
return {"room_id": room_id, "servers": servers}
|
return {"room_id": room_id, "servers": servers}
|
||||||
|
|
||||||
async def on_directory_query(self, args):
|
async def on_directory_query(self, args: JsonDict) -> JsonDict:
|
||||||
room_alias = RoomAlias.from_string(args["room_alias"])
|
room_alias = RoomAlias.from_string(args["room_alias"])
|
||||||
if not self.hs.is_mine(room_alias):
|
if not self.hs.is_mine(room_alias):
|
||||||
raise SynapseError(400, "Room Alias is not hosted on this homeserver")
|
raise SynapseError(400, "Room Alias is not hosted on this homeserver")
|
||||||
|
@ -293,7 +300,7 @@ class DirectoryHandler(BaseHandler):
|
||||||
|
|
||||||
async def _update_canonical_alias(
|
async def _update_canonical_alias(
|
||||||
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
|
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Send an updated canonical alias event if the removed alias was set as
|
Send an updated canonical alias event if the removed alias was set as
|
||||||
the canonical alias or listed in the alt_aliases field.
|
the canonical alias or listed in the alt_aliases field.
|
||||||
|
@ -344,7 +351,9 @@ class DirectoryHandler(BaseHandler):
|
||||||
ratelimit=False,
|
ratelimit=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_association_from_room_alias(self, room_alias: RoomAlias):
|
async def get_association_from_room_alias(
|
||||||
|
self, room_alias: RoomAlias
|
||||||
|
) -> Optional[RoomAliasMapping]:
|
||||||
result = await self.store.get_association_from_room_alias(room_alias)
|
result = await self.store.get_association_from_room_alias(room_alias)
|
||||||
if not result:
|
if not result:
|
||||||
# Query AS to see if it exists
|
# Query AS to see if it exists
|
||||||
|
@ -372,7 +381,7 @@ class DirectoryHandler(BaseHandler):
|
||||||
# either no interested services, or no service with an exclusive lock
|
# either no interested services, or no service with an exclusive lock
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
|
async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str) -> bool:
|
||||||
"""Determine whether a user can delete an alias.
|
"""Determine whether a user can delete an alias.
|
||||||
|
|
||||||
One of the following must be true:
|
One of the following must be true:
|
||||||
|
@ -394,14 +403,13 @@ class DirectoryHandler(BaseHandler):
|
||||||
if not room_id:
|
if not room_id:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
res = await self.auth.check_can_change_room_list(
|
return await self.auth.check_can_change_room_list(
|
||||||
room_id, UserID.from_string(user_id)
|
room_id, UserID.from_string(user_id)
|
||||||
)
|
)
|
||||||
return res
|
|
||||||
|
|
||||||
async def edit_published_room_list(
|
async def edit_published_room_list(
|
||||||
self, requester: Requester, room_id: str, visibility: str
|
self, requester: Requester, room_id: str, visibility: str
|
||||||
):
|
) -> None:
|
||||||
"""Edit the entry of the room in the published room list.
|
"""Edit the entry of the room in the published room list.
|
||||||
|
|
||||||
requester
|
requester
|
||||||
|
@ -469,7 +477,7 @@ class DirectoryHandler(BaseHandler):
|
||||||
|
|
||||||
async def edit_published_appservice_room_list(
|
async def edit_published_appservice_room_list(
|
||||||
self, appservice_id: str, network_id: str, room_id: str, visibility: str
|
self, appservice_id: str, network_id: str, room_id: str, visibility: str
|
||||||
):
|
) -> None:
|
||||||
"""Add or remove a room from the appservice/network specific public
|
"""Add or remove a room from the appservice/network specific public
|
||||||
room list.
|
room list.
|
||||||
|
|
||||||
|
@ -499,5 +507,4 @@ class DirectoryHandler(BaseHandler):
|
||||||
room_id, requester.user.to_string()
|
room_id, requester.user.to_string()
|
||||||
)
|
)
|
||||||
|
|
||||||
aliases = await self.store.get_aliases_for_room(room_id)
|
return await self.store.get_aliases_for_room(room_id)
|
||||||
return aliases
|
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
"""Utilities for interacting with Identity Servers"""
|
"""Utilities for interacting with Identity Servers"""
|
||||||
import logging
|
import logging
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
CodeMessageException,
|
CodeMessageException,
|
||||||
|
@ -41,13 +41,16 @@ from synapse.util.stringutils import (
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
id_server_scheme = "https://"
|
id_server_scheme = "https://"
|
||||||
|
|
||||||
|
|
||||||
class IdentityHandler(BaseHandler):
|
class IdentityHandler(BaseHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
# An HTTP client for contacting trusted URLs.
|
# An HTTP client for contacting trusted URLs.
|
||||||
|
@ -80,7 +83,7 @@ class IdentityHandler(BaseHandler):
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
medium: str,
|
medium: str,
|
||||||
address: str,
|
address: str,
|
||||||
):
|
) -> None:
|
||||||
"""Used to ratelimit requests to `/requestToken` by IP and address.
|
"""Used to ratelimit requests to `/requestToken` by IP and address.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ logger = logging.getLogger(__name__)
|
||||||
class MessageHandler:
|
class MessageHandler:
|
||||||
"""Contains some read only APIs to get state about a room"""
|
"""Contains some read only APIs to get state about a room"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
@ -91,7 +91,7 @@ class MessageHandler:
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_type: str,
|
event_type: str,
|
||||||
state_key: str,
|
state_key: str,
|
||||||
) -> dict:
|
) -> Optional[EventBase]:
|
||||||
"""Get data from a room.
|
"""Get data from a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -115,6 +115,10 @@ class MessageHandler:
|
||||||
data = await self.state.get_current_state(room_id, event_type, state_key)
|
data = await self.state.get_current_state(room_id, event_type, state_key)
|
||||||
elif membership == Membership.LEAVE:
|
elif membership == Membership.LEAVE:
|
||||||
key = (event_type, state_key)
|
key = (event_type, state_key)
|
||||||
|
# If the membership is not JOIN, then the event ID should exist.
|
||||||
|
assert (
|
||||||
|
membership_event_id is not None
|
||||||
|
), "check_user_in_room_or_world_readable returned invalid data"
|
||||||
room_state = await self.state_store.get_state_for_events(
|
room_state = await self.state_store.get_state_for_events(
|
||||||
[membership_event_id], StateFilter.from_types([key])
|
[membership_event_id], StateFilter.from_types([key])
|
||||||
)
|
)
|
||||||
|
@ -186,10 +190,12 @@ class MessageHandler:
|
||||||
|
|
||||||
event = last_events[0]
|
event = last_events[0]
|
||||||
if visible_events:
|
if visible_events:
|
||||||
room_state = await self.state_store.get_state_for_events(
|
room_state_events = await self.state_store.get_state_for_events(
|
||||||
[event.event_id], state_filter=state_filter
|
[event.event_id], state_filter=state_filter
|
||||||
)
|
)
|
||||||
room_state = room_state[event.event_id]
|
room_state = room_state_events[
|
||||||
|
event.event_id
|
||||||
|
] # type: Mapping[Any, EventBase]
|
||||||
else:
|
else:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403,
|
403,
|
||||||
|
@ -210,10 +216,14 @@ class MessageHandler:
|
||||||
)
|
)
|
||||||
room_state = await self.store.get_events(state_ids.values())
|
room_state = await self.store.get_events(state_ids.values())
|
||||||
elif membership == Membership.LEAVE:
|
elif membership == Membership.LEAVE:
|
||||||
room_state = await self.state_store.get_state_for_events(
|
# If the membership is not JOIN, then the event ID should exist.
|
||||||
|
assert (
|
||||||
|
membership_event_id is not None
|
||||||
|
), "check_user_in_room_or_world_readable returned invalid data"
|
||||||
|
room_state_events = await self.state_store.get_state_for_events(
|
||||||
[membership_event_id], state_filter=state_filter
|
[membership_event_id], state_filter=state_filter
|
||||||
)
|
)
|
||||||
room_state = room_state[membership_event_id]
|
room_state = room_state_events[membership_event_id]
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
events = await self._event_serializer.serialize_events(
|
events = await self._event_serializer.serialize_events(
|
||||||
|
|
|
@ -1044,7 +1044,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
|
|
||||||
class RoomMemberMasterHandler(RoomMemberHandler):
|
class RoomMemberMasterHandler(RoomMemberHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
self.distributor = hs.get_distributor()
|
self.distributor = hs.get_distributor()
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from twisted.web.client import PartialDownloadError
|
from twisted.web.client import PartialDownloadError
|
||||||
|
|
||||||
|
@ -22,13 +22,16 @@ from synapse.api.errors import Codes, LoginError, SynapseError
|
||||||
from synapse.config.emailconfig import ThreepidBehaviour
|
from synapse.config.emailconfig import ThreepidBehaviour
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UserInteractiveAuthChecker:
|
class UserInteractiveAuthChecker:
|
||||||
"""Abstract base class for an interactive auth checker"""
|
"""Abstract base class for an interactive auth checker"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def is_enabled(self) -> bool:
|
def is_enabled(self) -> bool:
|
||||||
|
@ -57,10 +60,10 @@ class UserInteractiveAuthChecker:
|
||||||
class DummyAuthChecker(UserInteractiveAuthChecker):
|
class DummyAuthChecker(UserInteractiveAuthChecker):
|
||||||
AUTH_TYPE = LoginType.DUMMY
|
AUTH_TYPE = LoginType.DUMMY
|
||||||
|
|
||||||
def is_enabled(self):
|
def is_enabled(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def check_auth(self, authdict, clientip):
|
async def check_auth(self, authdict: dict, clientip: str) -> Any:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@ -70,24 +73,24 @@ class TermsAuthChecker(UserInteractiveAuthChecker):
|
||||||
def is_enabled(self):
|
def is_enabled(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def check_auth(self, authdict, clientip):
|
async def check_auth(self, authdict: dict, clientip: str) -> Any:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class RecaptchaAuthChecker(UserInteractiveAuthChecker):
|
class RecaptchaAuthChecker(UserInteractiveAuthChecker):
|
||||||
AUTH_TYPE = LoginType.RECAPTCHA
|
AUTH_TYPE = LoginType.RECAPTCHA
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
self._enabled = bool(hs.config.recaptcha_private_key)
|
self._enabled = bool(hs.config.recaptcha_private_key)
|
||||||
self._http_client = hs.get_proxied_http_client()
|
self._http_client = hs.get_proxied_http_client()
|
||||||
self._url = hs.config.recaptcha_siteverify_api
|
self._url = hs.config.recaptcha_siteverify_api
|
||||||
self._secret = hs.config.recaptcha_private_key
|
self._secret = hs.config.recaptcha_private_key
|
||||||
|
|
||||||
def is_enabled(self):
|
def is_enabled(self) -> bool:
|
||||||
return self._enabled
|
return self._enabled
|
||||||
|
|
||||||
async def check_auth(self, authdict, clientip):
|
async def check_auth(self, authdict: dict, clientip: str) -> Any:
|
||||||
try:
|
try:
|
||||||
user_response = authdict["response"]
|
user_response = authdict["response"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -132,11 +135,11 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
|
||||||
|
|
||||||
|
|
||||||
class _BaseThreepidAuthChecker:
|
class _BaseThreepidAuthChecker:
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
async def _check_threepid(self, medium, authdict):
|
async def _check_threepid(self, medium: str, authdict: dict) -> dict:
|
||||||
if "threepid_creds" not in authdict:
|
if "threepid_creds" not in authdict:
|
||||||
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
|
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
|
||||||
|
|
||||||
|
@ -206,31 +209,31 @@ class _BaseThreepidAuthChecker:
|
||||||
class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
|
class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
|
||||||
AUTH_TYPE = LoginType.EMAIL_IDENTITY
|
AUTH_TYPE = LoginType.EMAIL_IDENTITY
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
UserInteractiveAuthChecker.__init__(self, hs)
|
UserInteractiveAuthChecker.__init__(self, hs)
|
||||||
_BaseThreepidAuthChecker.__init__(self, hs)
|
_BaseThreepidAuthChecker.__init__(self, hs)
|
||||||
|
|
||||||
def is_enabled(self):
|
def is_enabled(self) -> bool:
|
||||||
return self.hs.config.threepid_behaviour_email in (
|
return self.hs.config.threepid_behaviour_email in (
|
||||||
ThreepidBehaviour.REMOTE,
|
ThreepidBehaviour.REMOTE,
|
||||||
ThreepidBehaviour.LOCAL,
|
ThreepidBehaviour.LOCAL,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def check_auth(self, authdict, clientip):
|
async def check_auth(self, authdict: dict, clientip: str) -> Any:
|
||||||
return await self._check_threepid("email", authdict)
|
return await self._check_threepid("email", authdict)
|
||||||
|
|
||||||
|
|
||||||
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
|
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
|
||||||
AUTH_TYPE = LoginType.MSISDN
|
AUTH_TYPE = LoginType.MSISDN
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
UserInteractiveAuthChecker.__init__(self, hs)
|
UserInteractiveAuthChecker.__init__(self, hs)
|
||||||
_BaseThreepidAuthChecker.__init__(self, hs)
|
_BaseThreepidAuthChecker.__init__(self, hs)
|
||||||
|
|
||||||
def is_enabled(self):
|
def is_enabled(self) -> bool:
|
||||||
return bool(self.hs.config.account_threepid_delegate_msisdn)
|
return bool(self.hs.config.account_threepid_delegate_msisdn)
|
||||||
|
|
||||||
async def check_auth(self, authdict, clientip):
|
async def check_auth(self, authdict: dict, clientip: str) -> Any:
|
||||||
return await self._check_threepid("msisdn", authdict)
|
return await self._check_threepid("msisdn", authdict)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue