Add type hints to some handlers (#8505)

This commit is contained in:
Patrick Cloke 2020-10-09 07:20:51 -04:00 committed by GitHub
parent a97cec18bb
commit a93f3121f8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 60 additions and 22 deletions

1
changelog.d/8505.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to various parts of the code base.

View file

@ -15,9 +15,12 @@ files =
synapse/events/builder.py, synapse/events/builder.py,
synapse/events/spamcheck.py, synapse/events/spamcheck.py,
synapse/federation, synapse/federation,
synapse/handlers/account_data.py,
synapse/handlers/auth.py, synapse/handlers/auth.py,
synapse/handlers/cas_handler.py, synapse/handlers/cas_handler.py,
synapse/handlers/deactivate_account.py,
synapse/handlers/device.py, synapse/handlers/device.py,
synapse/handlers/devicemessage.py,
synapse/handlers/directory.py, synapse/handlers/directory.py,
synapse/handlers/events.py, synapse/handlers/events.py,
synapse/handlers/federation.py, synapse/handlers/federation.py,
@ -26,7 +29,9 @@ files =
synapse/handlers/message.py, synapse/handlers/message.py,
synapse/handlers/oidc_handler.py, synapse/handlers/oidc_handler.py,
synapse/handlers/pagination.py, synapse/handlers/pagination.py,
synapse/handlers/password_policy.py,
synapse/handlers/presence.py, synapse/handlers/presence.py,
synapse/handlers/read_marker.py,
synapse/handlers/room.py, synapse/handlers/room.py,
synapse/handlers/room_member.py, synapse/handlers/room_member.py,
synapse/handlers/room_member_worker.py, synapse/handlers/room_member_worker.py,

View file

@ -861,7 +861,7 @@ class FederationHandlerRegistry:
self._edu_type_to_instance = {} # type: Dict[str, str] self._edu_type_to_instance = {} # type: Dict[str, str]
def register_edu_handler( def register_edu_handler(
self, edu_type: str, handler: Callable[[str, dict], Awaitable[None]] self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
): ):
"""Sets the handler callable that will be used to handle an incoming """Sets the handler callable that will be used to handle an incoming
federation EDU of the given type. federation EDU of the given type.

View file

@ -12,16 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, List, Tuple
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
class AccountDataEventSource: class AccountDataEventSource:
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
def get_current_key(self, direction="f"): def get_current_key(self, direction: str = "f") -> int:
return self.store.get_max_account_data_stream_id() return self.store.get_max_account_data_stream_id()
async def get_new_events(self, user, from_key, **kwargs): async def get_new_events(
self, user: UserID, from_key: int, **kwargs
) -> Tuple[List[JsonDict], int]:
user_id = user.to_string() user_id = user.to_string()
last_stream_id = from_key last_stream_id = from_key

View file

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional from typing import TYPE_CHECKING, Optional
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
@ -22,13 +22,16 @@ from synapse.types import UserID, create_requester
from ._base import BaseHandler from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DeactivateAccountHandler(BaseHandler): class DeactivateAccountHandler(BaseHandler):
"""Handler which deals with deactivating user accounts.""" """Handler which deals with deactivating user accounts."""
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.hs = hs self.hs = hs
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
@ -137,7 +140,7 @@ class DeactivateAccountHandler(BaseHandler):
return identity_server_supports_unbinding return identity_server_supports_unbinding
async def _reject_pending_invites_for_user(self, user_id: str): async def _reject_pending_invites_for_user(self, user_id: str) -> None:
"""Reject pending invites addressed to a given user ID. """Reject pending invites addressed to a given user ID.
Args: Args:

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict from typing import TYPE_CHECKING, Any, Dict
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
@ -24,18 +24,22 @@ from synapse.logging.opentracing import (
set_tag, set_tag,
start_active_span, start_active_span,
) )
from synapse.types import UserID, get_domain_from_id from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DeviceMessageHandler: class DeviceMessageHandler:
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
""" """
Args: Args:
hs (synapse.server.HomeServer): server hs: server
""" """
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@ -48,7 +52,7 @@ class DeviceMessageHandler:
self._device_list_updater = hs.get_device_handler().device_list_updater self._device_list_updater = hs.get_device_handler().device_list_updater
async def on_direct_to_device_edu(self, origin, content): async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
local_messages = {} local_messages = {}
sender_user_id = content["sender"] sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id): if origin != get_domain_from_id(sender_user_id):
@ -95,7 +99,7 @@ class DeviceMessageHandler:
message_type: str, message_type: str,
sender_user_id: str, sender_user_id: str,
by_device: Dict[str, Dict[str, Any]], by_device: Dict[str, Dict[str, Any]],
): ) -> None:
"""Checks inbound device messages for unknown remote devices, and if """Checks inbound device messages for unknown remote devices, and if
found marks the remote cache for the user as stale. found marks the remote cache for the user as stale.
""" """
@ -138,11 +142,16 @@ class DeviceMessageHandler:
self._device_list_updater.user_device_resync, sender_user_id self._device_list_updater.user_device_resync, sender_user_id
) )
async def send_device_message(self, sender_user_id, message_type, messages): async def send_device_message(
self,
sender_user_id: str,
message_type: str,
messages: Dict[str, Dict[str, JsonDict]],
) -> None:
set_tag("number_of_messages", len(messages)) set_tag("number_of_messages", len(messages))
set_tag("sender", sender_user_id) set_tag("sender", sender_user_id)
local_messages = {} local_messages = {}
remote_messages = {} remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
for user_id, by_device in messages.items(): for user_id, by_device in messages.items():
# we use UserID.from_string to catch invalid user ids # we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)): if self.is_mine(UserID.from_string(user_id)):

View file

@ -16,14 +16,18 @@
import logging import logging
import re import re
from typing import TYPE_CHECKING
from synapse.api.errors import Codes, PasswordRefusedError from synapse.api.errors import Codes, PasswordRefusedError
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PasswordPolicyHandler: class PasswordPolicyHandler:
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.policy = hs.config.password_policy self.policy = hs.config.password_policy
self.enabled = hs.config.password_policy_enabled self.enabled = hs.config.password_policy_enabled
@ -33,11 +37,11 @@ class PasswordPolicyHandler:
self.regexp_uppercase = re.compile("[A-Z]") self.regexp_uppercase = re.compile("[A-Z]")
self.regexp_lowercase = re.compile("[a-z]") self.regexp_lowercase = re.compile("[a-z]")
def validate_password(self, password): def validate_password(self, password: str) -> None:
"""Checks whether a given password complies with the server's policy. """Checks whether a given password complies with the server's policy.
Args: Args:
password (str): The password to check against the server's policy. password: The password to check against the server's policy.
Raises: Raises:
PasswordRefusedError: The password doesn't comply with the server's policy. PasswordRefusedError: The password doesn't comply with the server's policy.

View file

@ -14,23 +14,29 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from ._base import BaseHandler from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReadMarkerHandler(BaseHandler): class ReadMarkerHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.server_name = hs.config.server_name self.server_name = hs.config.server_name
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.read_marker_linearizer = Linearizer(name="read_marker") self.read_marker_linearizer = Linearizer(name="read_marker")
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
async def received_client_read_marker(self, room_id, user_id, event_id): async def received_client_read_marker(
self, room_id: str, user_id: str, event_id: str
) -> None:
"""Updates the read marker for a given user in a given room if the event ID given """Updates the read marker for a given user in a given room if the event ID given
is ahead in the stream relative to the current read marker. is ahead in the stream relative to the current read marker.

View file

@ -339,7 +339,7 @@ class Notifier:
self, self,
stream_key: str, stream_key: str,
new_token: Union[int, RoomStreamToken], new_token: Union[int, RoomStreamToken],
users: Collection[UserID] = [], users: Collection[Union[str, UserID]] = [],
rooms: Collection[str] = [], rooms: Collection[str] = [],
): ):
""" Used to inform listeners that something has happened event wise. """ Used to inform listeners that something has happened event wise.

View file

@ -1220,7 +1220,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="record_user_external_id", desc="record_user_external_id",
) )
async def user_set_password_hash(self, user_id: str, password_hash: str) -> None: async def user_set_password_hash(
self, user_id: str, password_hash: Optional[str]
) -> None:
""" """
NB. This does *not* evict any cache because the one use for this NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be removes most of the entries subsequently anyway so it would be