0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-10 06:38:55 +02:00

Require type hints in the handlers module. (#10831)

Adds missing type hints to methods in the synapse.handlers
module and requires all methods to have type hints there.

This also removes the unused construct_auth_difference method
from the FederationHandler.
This commit is contained in:
Patrick Cloke 2021-09-20 08:56:23 -04:00 committed by GitHub
parent 437961744c
commit b3590614da
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
35 changed files with 194 additions and 295 deletions

1
changelog.d/10831.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to handlers.

View file

@ -91,6 +91,9 @@ files =
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py
[mypy-synapse.handlers.*]
disallow_untyped_defs = True
[mypy-synapse.rest.*]
disallow_untyped_defs = True

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List
from typing import Any, List, Tuple, Type
from synapse.util.module_loader import load_module
@ -25,7 +25,7 @@ class PasswordAuthProviderConfig(Config):
section = "authproviders"
def read_config(self, config, **kwargs):
self.password_providers: List[Any] = []
self.password_providers: List[Tuple[Type, Any]] = []
providers = []
# We want to be backwards compatible with the old `ldap_config`

View file

@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.ratelimiting import Ratelimiter
from synapse.types import Requester
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -63,16 +64,21 @@ class BaseHandler:
self.event_builder_factory = hs.get_event_builder_factory()
async def ratelimit(self, requester, update=True, is_admin_redaction=False):
async def ratelimit(
self,
requester: Requester,
update: bool = True,
is_admin_redaction: bool = False,
) -> None:
"""Ratelimits requests.
Args:
requester (Requester)
update (bool): Whether to record that a request is being processed.
requester
update: Whether to record that a request is being processed.
Set to False when doing multiple checks for one request (e.g.
to check up front if we would reject the request), and set to
True for the last call for a given request.
is_admin_redaction (bool): Whether this is a room admin/moderator
is_admin_redaction: Whether this is a room admin/moderator
redacting an event. If so then we may apply different
ratelimits depending on config.

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import TYPE_CHECKING, List, Tuple
from typing import TYPE_CHECKING, Any, List, Tuple
from synapse.replication.http.account_data import (
ReplicationAddTagRestServlet,
@ -171,7 +171,7 @@ class AccountDataEventSource:
return self.store.get_max_account_data_stream_id()
async def get_new_events(
self, user: UserID, from_key: int, **kwargs
self, user: UserID, from_key: int, **kwargs: Any
) -> Tuple[List[JsonDict], int]:
user_id = user.to_string()
last_stream_id = from_key

View file

@ -99,7 +99,7 @@ class AccountValidityHandler:
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
):
) -> None:
"""Register callbacks from module for each hook."""
if is_user_expired is not None:
self._is_user_expired_callbacks.append(is_user_expired)
@ -165,7 +165,7 @@ class AccountValidityHandler:
return False
async def on_user_registration(self, user_id: str):
async def on_user_registration(self, user_id: str) -> None:
"""Tell third-party modules about a user's registration.
Args:

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union
from prometheus_client import Counter
@ -58,7 +58,7 @@ class ApplicationServicesHandler:
self.current_max = 0
self.is_processing = False
def notify_interested_services(self, max_token: RoomStreamToken):
def notify_interested_services(self, max_token: RoomStreamToken) -> None:
"""Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any
@ -82,7 +82,7 @@ class ApplicationServicesHandler:
self._notify_interested_services(max_token)
@wrap_as_background_process("notify_interested_services")
async def _notify_interested_services(self, max_token: RoomStreamToken):
async def _notify_interested_services(self, max_token: RoomStreamToken) -> None:
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
@ -100,7 +100,7 @@ class ApplicationServicesHandler:
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
async def handle_event(event):
async def handle_event(event: EventBase) -> None:
# Gather interested services
services = await self._get_services_for_event(event)
if len(services) == 0:
@ -116,9 +116,9 @@ class ApplicationServicesHandler:
if not self.started_scheduler:
async def start_scheduler():
async def start_scheduler() -> None:
try:
return await self.scheduler.start()
await self.scheduler.start()
except Exception:
logger.error("Application Services Failure")
@ -137,7 +137,7 @@ class ApplicationServicesHandler:
"appservice_sender"
).observe((now - ts) / 1000)
async def handle_room_events(events):
async def handle_room_events(events: Iterable[EventBase]) -> None:
for event in events:
await handle_event(event)
@ -184,7 +184,7 @@ class ApplicationServicesHandler:
stream_key: str,
new_token: Optional[int],
users: Optional[Collection[Union[str, UserID]]] = None,
):
) -> None:
"""This is called by the notifier in the background
when a ephemeral event handled by the homeserver.
@ -226,7 +226,7 @@ class ApplicationServicesHandler:
stream_key: str,
new_token: Optional[int],
users: Collection[Union[str, UserID]],
):
) -> None:
logger.debug("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:

View file

@ -29,6 +29,7 @@ from typing import (
Mapping,
Optional,
Tuple,
Type,
Union,
cast,
)
@ -439,7 +440,7 @@ class AuthHandler(BaseHandler):
return ui_auth_types
def get_enabled_auth_types(self):
def get_enabled_auth_types(self) -> Iterable[str]:
"""Return the enabled user-interactive authentication types
Returns the UI-Auth types which are supported by the homeserver's current
@ -702,7 +703,7 @@ class AuthHandler(BaseHandler):
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
async def _expire_old_sessions(self):
async def _expire_old_sessions(self) -> None:
"""
Invalidate any user interactive authentication sessions that have expired.
"""
@ -1352,7 +1353,7 @@ class AuthHandler(BaseHandler):
await self.auth.check_auth_blocking(res.user_id)
return res
async def delete_access_token(self, access_token: str):
async def delete_access_token(self, access_token: str) -> None:
"""Invalidate a single access token
Args:
@ -1381,7 +1382,7 @@ class AuthHandler(BaseHandler):
user_id: str,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
):
) -> None:
"""Invalidate access tokens belonging to a user
Args:
@ -1409,7 +1410,7 @@ class AuthHandler(BaseHandler):
async def add_threepid(
self, user_id: str, medium: str, address: str, validated_at: int
):
) -> None:
# check if medium has a valid value
if medium not in ["email", "msisdn"]:
raise SynapseError(
@ -1480,7 +1481,7 @@ class AuthHandler(BaseHandler):
Hashed password.
"""
def _do_hash():
def _do_hash() -> str:
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
@ -1504,7 +1505,7 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash.
"""
def _do_validate_hash(checked_hash: bytes):
def _do_validate_hash(checked_hash: bytes) -> bool:
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
@ -1581,7 +1582,7 @@ class AuthHandler(BaseHandler):
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
):
) -> None:
"""Having figured out a mxid for this user, complete the HTTP request
Args:
@ -1627,7 +1628,7 @@ class AuthHandler(BaseHandler):
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
user_profile_data: Optional[ProfileInfo] = None,
):
) -> None:
"""
The synchronous portion of complete_sso_login.
@ -1726,7 +1727,7 @@ class AuthHandler(BaseHandler):
del self._extra_attributes[user_id]
@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
def add_query_param_to_url(url: str, param_name: str, param: Any) -> str:
url_parts = list(urllib.parse.urlparse(url))
query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
query.append((param_name, param))
@ -1734,9 +1735,9 @@ class AuthHandler(BaseHandler):
return urllib.parse.urlunparse(url_parts)
@attr.s(slots=True)
@attr.s(slots=True, auto_attribs=True)
class MacaroonGenerator:
hs = attr.ib()
hs: "HomeServer"
def generate_guest_access_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
@ -1816,7 +1817,9 @@ class PasswordProvider:
"""
@classmethod
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
def load(
cls, module: Type, config: JsonDict, module_api: ModuleApi
) -> "PasswordProvider":
try:
pp = module(config=config, account_handler=module_api)
except Exception as e:
@ -1824,7 +1827,7 @@ class PasswordProvider:
raise
return cls(pp, module_api)
def __init__(self, pp, module_api: ModuleApi):
def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
self._pp = pp
self._module_api = module_api
@ -1838,7 +1841,7 @@ class PasswordProvider:
if g:
self._supported_login_types.update(g())
def __str__(self):
def __str__(self) -> str:
return str(self._pp)
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
@ -1876,19 +1879,19 @@ class PasswordProvider:
"""
# first grandfather in a call to check_password
if login_type == LoginType.PASSWORD:
g = getattr(self._pp, "check_password", None)
if g:
check_password = getattr(self._pp, "check_password", None)
if check_password:
qualified_user_id = self._module_api.get_qualified_user_id(username)
is_valid = await self._pp.check_password(
is_valid = await check_password(
qualified_user_id, login_dict["password"]
)
if is_valid:
return qualified_user_id, None
g = getattr(self._pp, "check_auth", None)
if not g:
check_auth = getattr(self._pp, "check_auth", None)
if not check_auth:
return None
result = await g(username, login_type, login_dict)
result = await check_auth(username, login_type, login_dict)
# Check if the return value is a str or a tuple
if isinstance(result, str):

View file

@ -34,20 +34,20 @@ logger = logging.getLogger(__name__)
class CasError(Exception):
"""Used to catch errors when validating the CAS ticket."""
def __init__(self, error, error_description=None):
def __init__(self, error: str, error_description: Optional[str] = None):
self.error = error
self.error_description = error_description
def __str__(self):
def __str__(self) -> str:
if self.error_description:
return f"{self.error}: {self.error_description}"
return self.error
@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class CasResponse:
username = attr.ib(type=str)
attributes = attr.ib(type=Dict[str, List[Optional[str]]])
username: str
attributes: Dict[str, List[Optional[str]]]
class CasHandler:
@ -133,11 +133,9 @@ class CasHandler:
body = pde.response
except HttpResponseException as e:
description = (
(
'Authorization server responded with a "{status}" error '
"while exchanging the authorization code."
).format(status=e.code),
)
'Authorization server responded with a "{status}" error '
"while exchanging the authorization code."
).format(status=e.code)
raise CasError("server_error", description) from e
return self._parse_cas_response(body)

View file

@ -267,7 +267,7 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
def _check_device_name_length(self, name: Optional[str]):
def _check_device_name_length(self, name: Optional[str]) -> None:
"""
Checks whether a device name is longer than the maximum allowed length.

View file

@ -202,7 +202,7 @@ class E2eKeysHandler:
# Now fetch any devices that we don't have in our cache
@trace
async def do_remote_query(destination):
async def do_remote_query(destination: str) -> None:
"""This is called when we are querying the device list of a user on
a remote homeserver and their device list is not in the device list
cache. If we share a room with this user and we're not querying for
@ -447,7 +447,7 @@ class E2eKeysHandler:
}
@trace
async def claim_client_keys(destination):
async def claim_client_keys(destination: str) -> None:
set_tag("destination", destination)
device_keys = remote_queries[destination]
try:

View file

@ -25,6 +25,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
from synapse.types import StateMap, get_domain_from_id
from synapse.util.metrics import Measure
@ -45,7 +46,11 @@ class EventAuthHandler:
self._server_name = hs.hostname
async def check_from_context(
self, room_version: str, event, context, do_sig_check=True
self,
room_version: str,
event: EventBase,
context: EventContext,
do_sig_check: bool = True,
) -> None:
auth_event_ids = event.auth_event_ids()
auth_events_by_id = await self._store.get_events(auth_event_ids)

View file

@ -1221,136 +1221,6 @@ class FederationHandler(BaseHandler):
return missing_events
async def construct_auth_difference(
self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
) -> Dict:
"""Given a local and remote auth chain, find the differences. This
assumes that we have already processed all events in remote_auth
Params:
local_auth
remote_auth
Returns:
dict
"""
logger.debug("construct_auth_difference Start!")
# TODO: Make sure we are OK with local_auth or remote_auth having more
# auth events in them than strictly necessary.
def sort_fun(ev):
return ev.depth, ev.event_id
logger.debug("construct_auth_difference after sort_fun!")
# We find the differences by starting at the "bottom" of each list
# and iterating up on both lists. The lists are ordered by depth and
# then event_id, we iterate up both lists until we find the event ids
# don't match. Then we look at depth/event_id to see which side is
# missing that event, and iterate only up that list. Repeat.
remote_list = list(remote_auth)
remote_list.sort(key=sort_fun)
local_list = list(local_auth)
local_list.sort(key=sort_fun)
local_iter = iter(local_list)
remote_iter = iter(remote_list)
logger.debug("construct_auth_difference before get_next!")
def get_next(it, opt=None):
try:
return next(it)
except Exception:
return opt
current_local = get_next(local_iter)
current_remote = get_next(remote_iter)
logger.debug("construct_auth_difference before while")
missing_remotes = []
missing_locals = []
while current_local or current_remote:
if current_remote is None:
missing_locals.append(current_local)
current_local = get_next(local_iter)
continue
if current_local is None:
missing_remotes.append(current_remote)
current_remote = get_next(remote_iter)
continue
if current_local.event_id == current_remote.event_id:
current_local = get_next(local_iter)
current_remote = get_next(remote_iter)
continue
if current_local.depth < current_remote.depth:
missing_locals.append(current_local)
current_local = get_next(local_iter)
continue
if current_local.depth > current_remote.depth:
missing_remotes.append(current_remote)
current_remote = get_next(remote_iter)
continue
# They have the same depth, so we fall back to the event_id order
if current_local.event_id < current_remote.event_id:
missing_locals.append(current_local)
current_local = get_next(local_iter)
if current_local.event_id > current_remote.event_id:
missing_remotes.append(current_remote)
current_remote = get_next(remote_iter)
continue
logger.debug("construct_auth_difference after while")
# missing locals should be sent to the server
# We should find why we are missing remotes, as they will have been
# rejected.
# Remove events from missing_remotes if they are referencing a missing
# remote. We only care about the "root" rejected ones.
missing_remote_ids = [e.event_id for e in missing_remotes]
base_remote_rejected = list(missing_remotes)
for e in missing_remotes:
for e_id in e.auth_event_ids():
if e_id in missing_remote_ids:
try:
base_remote_rejected.remove(e)
except ValueError:
pass
reason_map = {}
for e in base_remote_rejected:
reason = await self.store.get_rejection_reason(e.event_id)
if reason is None:
# TODO: e is not in the current state, so we should
# construct some proof of that.
continue
reason_map[e.event_id] = reason
logger.debug("construct_auth_difference returning")
return {
"auth_chain": local_auth,
"rejects": {
e.event_id: {"reason": reason_map[e.event_id], "proof": None}
for e in base_remote_rejected
},
"missing": [e.event_id for e in missing_locals],
}
@log_function
async def exchange_third_party_invite(
self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict

View file

@ -1016,7 +1016,7 @@ class FederationEventHandler:
except Exception:
logger.exception("Failed to resync device for %s", sender)
async def _handle_marker_event(self, origin: str, marker_event: EventBase):
async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None:
"""Handles backfilling the insertion event when we receive a marker
event that points to one.
@ -1109,7 +1109,7 @@ class FederationEventHandler:
event_map: Dict[str, EventBase] = {}
async def get_event(event_id: str):
async def get_event(event_id: str) -> None:
with nested_logging_context(event_id):
try:
event = await self._federation_client.get_pdu(
@ -1218,7 +1218,7 @@ class FederationEventHandler:
if not event_infos:
return
async def prep(ev_info: _NewEventInfo):
async def prep(ev_info: _NewEventInfo) -> EventContext:
event = ev_info.event
with nested_logging_context(suffix=event.event_id):
res = await self._state_handler.compute_event_context(event)
@ -1692,7 +1692,7 @@ class FederationEventHandler:
async def _run_push_actions_and_persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
):
) -> None:
"""Run the push actions for a received event, and persist it.
Args:

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Set
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterable, List, Set
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.types import GroupID, JsonDict, get_domain_from_id
@ -25,12 +25,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _create_rerouter(func_name):
def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]:
"""Returns an async function that looks at the group id and calls the function
on federation or the local group server if the group is local
"""
async def f(self, group_id, *args, **kwargs):
async def f(
self: "GroupsLocalWorkerHandler", group_id: str, *args: Any, **kwargs: Any
) -> JsonDict:
if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s is not a legal group ID" % (group_id,))

View file

@ -13,7 +13,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.internet import defer
@ -150,7 +150,7 @@ class InitialSyncHandler(BaseHandler):
if limit is None:
limit = 10
async def handle_room(event: RoomsForUser):
async def handle_room(event: RoomsForUser) -> None:
d: JsonDict = {
"room_id": event.room_id,
"membership": event.membership,
@ -411,7 +411,7 @@ class InitialSyncHandler(BaseHandler):
presence_handler = self.hs.get_presence_handler()
async def get_presence():
async def get_presence() -> List[JsonDict]:
# If presence is disabled, return an empty list
if not self.hs.config.server.use_presence:
return []
@ -428,7 +428,7 @@ class InitialSyncHandler(BaseHandler):
for s in states
]
async def get_receipts():
async def get_receipts() -> List[JsonDict]:
receipts = await self.store.get_linearized_receipts_for_room(
room_id, to_key=now_token.receipt_key
)

View file

@ -46,6 +46,7 @@ from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
@ -298,7 +299,7 @@ class MessageHandler:
for user_id, profile in users_with_profile.items()
}
def maybe_schedule_expiry(self, event: EventBase):
def maybe_schedule_expiry(self, event: EventBase) -> None:
"""Schedule the expiry of an event if there's not already one scheduled,
or if the one running is for an event that will expire after the provided
timestamp.
@ -318,7 +319,7 @@ class MessageHandler:
# a task scheduled for a timestamp that's sooner than the provided one.
self._schedule_expiry_for_event(event.event_id, expiry_ts)
async def _schedule_next_expiry(self):
async def _schedule_next_expiry(self) -> None:
"""Retrieve the ID and the expiry timestamp of the next event to be expired,
and schedule an expiry task for it.
@ -331,7 +332,7 @@ class MessageHandler:
event_id, expiry_ts = res
self._schedule_expiry_for_event(event_id, expiry_ts)
def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int):
def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int) -> None:
"""Schedule an expiry task for the provided event if there's not already one
scheduled at a timestamp that's sooner than the provided one.
@ -367,7 +368,7 @@ class MessageHandler:
event_id,
)
async def _expire_event(self, event_id: str):
async def _expire_event(self, event_id: str) -> None:
"""Retrieve and expire an event that needs to be expired from the database.
If the event doesn't exist in the database, log it and delete the expiry date
@ -1229,7 +1230,10 @@ class EventCreationHandler:
self._external_cache_joined_hosts_updates[state_entry.state_group] = None
async def _validate_canonical_alias(
self, directory_handler, room_alias_str: str, expected_room_id: str
self,
directory_handler: DirectoryHandler,
room_alias_str: str,
expected_room_id: str,
) -> None:
"""
Ensure that the given room alias points to the expected room ID.
@ -1477,7 +1481,7 @@ class EventCreationHandler:
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
def _notify():
def _notify() -> None:
try:
self.notifier.on_new_room_event(
event, event_pos, max_stream_token, extra_users=extra_users
@ -1523,7 +1527,7 @@ class EventCreationHandler:
except Exception:
logger.exception("Error bumping presence active time")
async def _send_dummy_events_to_fill_extremities(self):
async def _send_dummy_events_to_fill_extremities(self) -> None:
"""Background task to send dummy events into rooms that have a large
number of extremities
"""
@ -1600,7 +1604,7 @@ class EventCreationHandler:
)
return False
def _expire_rooms_to_exclude_from_dummy_event_insertion(self):
def _expire_rooms_to_exclude_from_dummy_event_insertion(self) -> None:
expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
to_expire = set()
for room_id, time in self._rooms_to_exclude_from_dummy_event_insertion.items():

View file

@ -14,7 +14,7 @@
# limitations under the License.
import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
from urllib.parse import urlencode, urlparse
import attr
@ -249,11 +249,11 @@ class OidcHandler:
class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint"""
def __init__(self, error, error_description=None):
def __init__(self, error: str, error_description: Optional[str] = None):
self.error = error
self.error_description = error_description
def __str__(self):
def __str__(self) -> str:
if self.error_description:
return f"{self.error}: {self.error_description}"
return self.error
@ -1057,13 +1057,13 @@ class JwtClientSecret:
self._cached_secret = b""
self._cached_secret_replacement_time = 0
def __str__(self):
def __str__(self) -> str:
# if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
# encode_client_secret_basic, which calls "{}".format(secret), which ends up
# here.
return self._get_secret().decode("ascii")
def __bytes__(self):
def __bytes__(self) -> bytes:
# if client_auth_method is client_secret_post, then ClientAuth.prepare calls
# encode_client_secret_post, which ends up here.
return self._get_secret()
@ -1197,21 +1197,21 @@ class OidcSessionTokenGenerator:
)
@attr.s(frozen=True, slots=True)
@attr.s(frozen=True, slots=True, auto_attribs=True)
class OidcSessionData:
"""The attributes which are stored in a OIDC session cookie"""
# the Identity Provider being used
idp_id = attr.ib(type=str)
idp_id: str
# The `nonce` parameter passed to the OIDC provider.
nonce = attr.ib(type=str)
nonce: str
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
client_redirect_url = attr.ib(type=str)
client_redirect_url: str
# The session ID of the ongoing UI Auth ("" if this is a login)
ui_auth_session_id = attr.ib(type=str)
ui_auth_session_id: str
class UserAttributeDict(TypedDict):
@ -1290,20 +1290,20 @@ class OidcMappingProvider(Generic[C]):
# Used to clear out "None" values in templates
def jinja_finalize(thing):
def jinja_finalize(thing: Any) -> Any:
return thing if thing is not None else ""
env = Environment(finalize=jinja_finalize)
@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class JinjaOidcMappingConfig:
subject_claim = attr.ib(type=str)
localpart_template = attr.ib(type=Optional[Template])
display_name_template = attr.ib(type=Optional[Template])
email_template = attr.ib(type=Optional[Template])
extra_attributes = attr.ib(type=Dict[str, Template])
subject_claim: str
localpart_template: Optional[Template]
display_name_template: Optional[Template]
email_template: Optional[Template]
extra_attributes: Dict[str, Template]
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):

View file

@ -15,6 +15,8 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, Optional, Set
import attr
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
@ -24,7 +26,7 @@ from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import Requester
from synapse.types import JsonDict, Requester
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
@ -36,15 +38,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@attr.s(slots=True, auto_attribs=True)
class PurgeStatus:
"""Object tracking the status of a purge request
This class contains information on the progress of a purge request, for
return by get_purge_status.
Attributes:
status (int): Tracks whether this request has completed. One of
STATUS_{ACTIVE,COMPLETE,FAILED}
"""
STATUS_ACTIVE = 0
@ -57,10 +56,10 @@ class PurgeStatus:
STATUS_FAILED: "failed",
}
def __init__(self):
self.status = PurgeStatus.STATUS_ACTIVE
# Tracks whether this request has completed. One of STATUS_{ACTIVE,COMPLETE,FAILED}.
status: int = STATUS_ACTIVE
def asdict(self):
def asdict(self) -> JsonDict:
return {"status": PurgeStatus.STATUS_TEXT[self.status]}
@ -107,7 +106,7 @@ class PaginationHandler:
async def purge_history_for_rooms_in_range(
self, min_ms: Optional[int], max_ms: Optional[int]
):
) -> None:
"""Purge outdated events from rooms within the given retention range.
If a default retention policy is defined in the server's configuration and its
@ -291,7 +290,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.discard(room_id)
# remove the purge from the list 24 hours after it completes
def clear_purge():
def clear_purge() -> None:
del self._purges_by_id[purge_id]
self.hs.get_reactor().callLater(24 * 3600, clear_purge)

View file

@ -26,18 +26,22 @@ import contextlib
import logging
from bisect import bisect
from contextlib import contextmanager
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Collection,
Dict,
FrozenSet,
Generator,
Iterable,
List,
Optional,
Set,
Tuple,
Type,
Union,
)
@ -240,7 +244,7 @@ class BasePresenceHandler(abc.ABC):
"""
@abc.abstractmethod
async def bump_presence_active_time(self, user: UserID):
async def bump_presence_active_time(self, user: UserID) -> None:
"""We've seen the user do something that indicates they're interacting
with the app.
"""
@ -274,7 +278,7 @@ class BasePresenceHandler(abc.ABC):
async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list
):
) -> None:
"""Process streams received over replication."""
await self._federation_queue.process_replication_rows(
stream_name, instance_name, token, rows
@ -286,7 +290,7 @@ class BasePresenceHandler(abc.ABC):
async def maybe_send_presence_to_interested_destinations(
self, states: List[UserPresenceState]
):
) -> None:
"""If this instance is a federation sender, send the states to all
destinations that are interested. Filters out any states for remote
users.
@ -309,7 +313,7 @@ class BasePresenceHandler(abc.ABC):
for destination, host_states in hosts_to_states.items():
self._federation.send_presence_to_destinations(host_states, [destination])
async def send_full_presence_to_users(self, user_ids: Collection[str]):
async def send_full_presence_to_users(self, user_ids: Collection[str]) -> None:
"""
Adds to the list of users who should receive a full snapshot of presence
upon their next sync. Note that this only works for local users.
@ -363,7 +367,12 @@ class BasePresenceHandler(abc.ABC):
class _NullContextManager(ContextManager[None]):
"""A context manager which does nothing."""
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
pass
@ -468,7 +477,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
if self._user_to_num_current_syncs[user_id] == 1:
self.mark_as_coming_online(user_id)
def _end():
def _end() -> None:
# We check that the user_id is in user_to_num_current_syncs because
# user_to_num_current_syncs may have been cleared if we are
# shutting down.
@ -480,7 +489,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
self.mark_as_going_offline(user_id)
@contextlib.contextmanager
def _user_syncing():
def _user_syncing() -> Generator[None, None, None]:
try:
yield
finally:
@ -503,7 +512,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list
):
) -> None:
await super().process_replication_rows(stream_name, instance_name, token, rows)
if stream_name != PresenceStream.NAME:
@ -689,7 +698,7 @@ class PresenceHandler(BasePresenceHandler):
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
def run_timeout_handler():
def run_timeout_handler() -> Awaitable[None]:
return run_as_background_process(
"handle_presence_timeouts", self._handle_timeouts
)
@ -698,7 +707,7 @@ class PresenceHandler(BasePresenceHandler):
30, self.clock.looping_call, run_timeout_handler, 5000
)
def run_persister():
def run_persister() -> Awaitable[None]:
return run_as_background_process(
"persist_presence_changes", self._persist_unpersisted_changes
)
@ -942,8 +951,8 @@ class PresenceHandler(BasePresenceHandler):
when users disconnect/reconnect.
Args:
user_id (str)
affect_presence (bool): If false this function will be a no-op.
user_id
affect_presence: If false this function will be a no-op.
Useful for streams that are not associated with an actual
client that is being used by a user.
"""
@ -978,7 +987,7 @@ class PresenceHandler(BasePresenceHandler):
]
)
async def _end():
async def _end() -> None:
try:
self.user_to_num_current_syncs[user_id] -= 1
@ -994,7 +1003,7 @@ class PresenceHandler(BasePresenceHandler):
logger.exception("Error updating presence after sync")
@contextmanager
def _user_syncing():
def _user_syncing() -> Generator[None, None, None]:
try:
yield
finally:
@ -1264,7 +1273,7 @@ class PresenceHandler(BasePresenceHandler):
if self._event_processing:
return
async def _process_presence():
async def _process_presence() -> None:
assert not self._event_processing
self._event_processing = True
@ -1513,7 +1522,7 @@ class PresenceEventSource:
room_ids: Optional[List[str]] = None,
include_offline: bool = True,
explicit_room_id: Optional[str] = None,
**kwargs,
**kwargs: Any,
) -> Tuple[List[UserPresenceState], int]:
# The process for getting presence events are:
# 1. Get the rooms the user is in.
@ -2074,7 +2083,7 @@ class PresenceFederationQueue:
if self._queue_presence_updates:
self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
def _clear_queue(self):
def _clear_queue(self) -> None:
"""Clear out older entries from the queue."""
clear_before = self._clock.time_msec() - self._KEEP_ITEMS_IN_QUEUE_FOR_MS
@ -2205,7 +2214,7 @@ class PresenceFederationQueue:
async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list
):
) -> None:
if stream_name != PresenceFederationStream.NAME:
return

View file

@ -254,7 +254,7 @@ class ProfileHandler(BaseHandler):
requester: Requester,
new_avatar_url: str,
by_admin: bool = False,
):
) -> None:
"""Set a new avatar URL for a user.
Args:
@ -425,7 +425,7 @@ class ProfileHandler(BaseHandler):
raise
@wrap_as_background_process("Update remote profile")
async def _update_remote_profile_cache(self):
async def _update_remote_profile_cache(self) -> None:
"""Called periodically to check profiles of remote users we haven't
checked in a while.
"""

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.appservice import ApplicationService
@ -216,7 +216,7 @@ class ReceiptEventSource:
return visible_events
async def get_new_events(
self, from_key: int, room_ids: List[str], user: UserID, **kwargs
self, from_key: int, room_ids: List[str], user: UserID, **kwargs: Any
) -> Tuple[List[JsonDict], int]:
from_key = int(from_key)
to_key = self.get_current_key()

View file

@ -125,7 +125,7 @@ class RegistrationHandler(BaseHandler):
localpart: str,
guest_access_token: Optional[str] = None,
assigned_user_id: Optional[str] = None,
):
) -> None:
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,

View file

@ -1,6 +1,4 @@
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -186,7 +184,7 @@ class RoomCreationHandler(BaseHandler):
async def _upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
) -> str:
"""
Args:
requester: the user requesting the upgrade
@ -512,7 +510,7 @@ class RoomCreationHandler(BaseHandler):
old_room_id: str,
new_room_id: str,
old_room_state: StateMap[str],
):
) -> None:
# check to see if we have a canonical alias.
canonical_alias_event = None
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
@ -902,7 +900,7 @@ class RoomCreationHandler(BaseHandler):
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
def create(etype: str, content: JsonDict, **kwargs) -> JsonDict:
def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
e = {"type": etype, "content": content}
e.update(event_keys)
@ -910,7 +908,7 @@ class RoomCreationHandler(BaseHandler):
return e
async def send(etype: str, content: JsonDict, **kwargs) -> int:
async def send(etype: str, content: JsonDict, **kwargs: Any) -> int:
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
# Allow these events to be sent even if the user is shadow-banned to
@ -1033,7 +1031,7 @@ class RoomCreationHandler(BaseHandler):
creator_id: str,
is_public: bool,
room_version: RoomVersion,
):
) -> str:
# autogen room IDs and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
attempts = 0
@ -1097,7 +1095,7 @@ class RoomContextHandler:
users = await self.store.get_users_in_room(room_id)
is_peeking = user.to_string() not in users
async def filter_evts(events):
async def filter_evts(events: List[EventBase]) -> List[EventBase]:
if use_admin_priviledge:
return events
return await filter_events_for_client(

View file

@ -14,7 +14,7 @@
import logging
from collections import namedtuple
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Any, Optional, Tuple
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
@ -33,7 +33,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.response_cache import ResponseCache
from ._base import BaseHandler
@ -169,7 +169,7 @@ class RoomListHandler(BaseHandler):
ignore_non_federatable=from_federation,
)
def build_room_entry(room):
def build_room_entry(room: JsonDict) -> JsonDict:
entry = {
"room_id": room["room_id"],
"name": room["name"],
@ -249,10 +249,10 @@ class RoomListHandler(BaseHandler):
self,
room_id: str,
num_joined_users: int,
cache_context,
cache_context: _CacheContext,
with_alias: bool = True,
allow_private: bool = False,
) -> Optional[dict]:
) -> Optional[JsonDict]:
"""Returns the entry for a room
Args:
@ -507,7 +507,7 @@ class RoomListNextBatch(
)
)
def copy_and_replace(self, **kwds) -> "RoomListNextBatch":
def copy_and_replace(self, **kwds: Any) -> "RoomListNextBatch":
return self._replace(**kwds)

View file

@ -225,7 +225,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
room_id: Optional[str],
n_invites: int,
update: bool = True,
):
) -> None:
"""Ratelimit more than one invite sent by the given requester in the given room.
Args:
@ -249,7 +249,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
requester: Optional[Requester],
room_id: Optional[str],
invitee_user_id: str,
):
) -> None:
"""Ratelimit invites by room and by target user.
If room ID is missing then we just rate limit by target user.
@ -386,7 +386,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
return result_event.event_id, result_event.internal_metadata.stream_ordering
async def copy_room_tags_and_direct_to_room(
self, old_room_id, new_room_id, user_id
self, old_room_id: str, new_room_id: str, user_id: str
) -> None:
"""Copies the tags and direct room state from one room to another.
@ -1030,7 +1030,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
event: EventBase,
context: EventContext,
ratelimit: bool = True,
):
) -> None:
"""
Change the membership status of a user in a room.

View file

@ -541,7 +541,7 @@ class RoomSummaryHandler:
origin: str,
requested_room_id: str,
suggested_only: bool,
):
) -> JsonDict:
"""
Implementation of the room hierarchy Federation API.

View file

@ -40,15 +40,15 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@attr.s(slots=True)
@attr.s(slots=True, auto_attribs=True)
class Saml2SessionData:
"""Data we track about SAML2 sessions"""
# time the session was created, in milliseconds
creation_time = attr.ib()
creation_time: int
# The user interactive authentication session ID associated with this SAML
# session (or None if this SAML session is for an initial login).
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
ui_auth_session_id: Optional[str] = None
class SamlHandler(BaseHandler):
@ -359,7 +359,7 @@ class SamlHandler(BaseHandler):
return remote_user_id
def expire_sessions(self):
def expire_sessions(self) -> None:
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
for reqid, data in self._outstanding_requests_dict.items():
@ -391,10 +391,10 @@ MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = {
}
@attr.s
@attr.s(auto_attribs=True)
class SamlConfig:
mxid_source_attribute = attr.ib()
mxid_mapper = attr.ib()
mxid_source_attribute: str
mxid_mapper: Callable[[str], str]
class DefaultSamlMappingProvider:

View file

@ -17,7 +17,7 @@ import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from io import BytesIO
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional
from pkg_resources import parse_version
@ -79,7 +79,7 @@ async def _sendmail(
msg = BytesIO(msg_bytes)
d: "Deferred[object]" = Deferred()
def build_sender_factory(**kwargs) -> ESMTPSenderFactory:
def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory:
return ESMTPSenderFactory(
username,
password,

View file

@ -205,7 +205,7 @@ class SsoHandler:
self._consent_at_registration = hs.config.consent.user_consent_at_registration
def register_identity_provider(self, p: SsoIdentityProvider):
def register_identity_provider(self, p: SsoIdentityProvider) -> None:
p_id = p.idp_id
assert p_id not in self._identity_providers
self._identity_providers[p_id] = p
@ -856,7 +856,7 @@ class SsoHandler:
async def handle_terms_accepted(
self, request: Request, session_id: str, terms_version: str
):
) -> None:
"""Handle a request to the new-user 'consent' endpoint
Will serve an HTTP response to the request.
@ -959,7 +959,7 @@ class SsoHandler:
new_user=True,
)
def _expire_old_sessions(self):
def _expire_old_sessions(self) -> None:
to_expire = []
now = int(self._clock.time_msec())

View file

@ -68,7 +68,7 @@ class StatsHandler:
self._is_processing = True
async def process():
async def process() -> None:
try:
await self._unsafe_process()
finally:

View file

@ -364,7 +364,9 @@ class SyncHandler:
)
else:
async def current_sync_callback(before_token, after_token) -> SyncResult:
async def current_sync_callback(
before_token: StreamToken, after_token: StreamToken
) -> SyncResult:
return await self.current_sync_for_user(sync_config, since_token)
result = await self.notifier.wait_for_events(
@ -1532,9 +1534,9 @@ class SyncHandler:
newly_joined_rooms = room_changes.newly_joined_rooms
newly_left_rooms = room_changes.newly_left_rooms
async def handle_room_entries(room_entry: "RoomSyncResultBuilder"):
async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None:
logger.debug("Generating room entry for %s", room_entry.room_id)
res = await self._generate_room_entry(
await self._generate_room_entry(
sync_result_builder,
ignored_users,
room_entry,
@ -1544,7 +1546,6 @@ class SyncHandler:
always_include=sync_result_builder.full_state,
)
logger.debug("Generated room entry for %s", room_entry.room_id)
return res
await concurrently_execute(handle_room_entries, room_entries, 10)
@ -1925,7 +1926,7 @@ class SyncHandler:
tags: Optional[Dict[str, Dict[str, Any]]],
account_data: Dict[str, JsonDict],
always_include: bool = False,
):
) -> None:
"""Populates the `joined` and `archived` section of `sync_result_builder`
based on the `room_builder`.

View file

@ -14,7 +14,7 @@
import logging
import random
from collections import namedtuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
@ -485,7 +485,7 @@ class TypingNotificationEventSource:
return (events, handler._latest_room_serial)
async def get_new_events(
self, from_key: int, room_ids: Iterable[str], **kwargs
self, from_key: int, room_ids: Iterable[str], **kwargs: Any
) -> Tuple[List[JsonDict], int]:
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)

View file

@ -70,7 +70,7 @@ class DummyAuthChecker(UserInteractiveAuthChecker):
class TermsAuthChecker(UserInteractiveAuthChecker):
AUTH_TYPE = LoginType.TERMS
def is_enabled(self):
def is_enabled(self) -> bool:
return True
async def check_auth(self, authdict: dict, clientip: str) -> Any:

View file

@ -114,7 +114,7 @@ class UserDirectoryHandler(StateDeltasHandler):
if self._is_processing:
return
async def process():
async def process() -> None:
try:
await self._unsafe_process()
finally: