0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-30 01:33:59 +01:00

Add missing type hints to synapse.api. (#11109)

* Convert UserPresenceState to attrs.
* Remove args/kwargs from error classes and explicitly pass msg/errorcode.
This commit is contained in:
Patrick Cloke 2021-10-18 15:01:10 -04:00 committed by GitHub
parent cc33d9eee2
commit 3ab55d43bd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 84 additions and 99 deletions

1
changelog.d/11109.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to `synapse.api` module.

View file

@ -100,6 +100,9 @@ files =
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py
[mypy-synapse.api.*]
disallow_untyped_defs = True
[mypy-synapse.events.*]
disallow_untyped_defs = True

View file

@ -245,7 +245,7 @@ class Auth:
async def validate_appservice_can_control_user_id(
self, app_service: ApplicationService, user_id: str
):
) -> None:
"""Validates that the app service is allowed to control
the given user.
@ -618,5 +618,13 @@ class Auth:
% (user_id, room_id),
)
async def check_auth_blocking(self, *args, **kwargs) -> None:
await self._auth_blocking.check_auth_blocking(*args, **kwargs)
async def check_auth_blocking(
self,
user_id: Optional[str] = None,
threepid: Optional[dict] = None,
user_type: Optional[str] = None,
requester: Optional[Requester] = None,
) -> None:
await self._auth_blocking.check_auth_blocking(
user_id=user_id, threepid=threepid, user_type=user_type, requester=requester
)

View file

@ -18,7 +18,7 @@
import logging
import typing
from http import HTTPStatus
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from twisted.web import http
@ -143,7 +143,7 @@ class SynapseError(CodeMessageException):
super().__init__(code, msg)
self.errcode = errcode
def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode)
@ -175,7 +175,7 @@ class ProxiedRequestError(SynapseError):
else:
self._additional_fields = dict(additional_fields)
def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, **self._additional_fields)
@ -196,7 +196,7 @@ class ConsentNotGivenError(SynapseError):
)
self._consent_uri = consent_uri
def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri)
@ -262,14 +262,10 @@ class InteractiveAuthIncompleteError(Exception):
class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.UNRECOGNIZED
if len(args) == 0:
message = "Unrecognized request"
else:
message = args[0]
super().__init__(400, message, **kwargs)
def __init__(
self, msg: str = "Unrecognized request", errcode: str = Codes.UNRECOGNIZED
):
super().__init__(400, msg, errcode)
class NotFoundError(SynapseError):
@ -284,10 +280,8 @@ class AuthError(SynapseError):
other poorly-defined times.
"""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.FORBIDDEN
super().__init__(*args, **kwargs)
def __init__(self, code: int, msg: str, errcode: str = Codes.FORBIDDEN):
super().__init__(code, msg, errcode)
class InvalidClientCredentialsError(SynapseError):
@ -321,7 +315,7 @@ class InvalidClientTokenError(InvalidClientCredentialsError):
super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
self._soft_logout = soft_logout
def error_dict(self):
def error_dict(self) -> "JsonDict":
d = super().error_dict()
d["soft_logout"] = self._soft_logout
return d
@ -345,7 +339,7 @@ class ResourceLimitError(SynapseError):
self.limit_type = limit_type
super().__init__(code, msg, errcode=errcode)
def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(
self.msg,
self.errcode,
@ -357,32 +351,17 @@ class ResourceLimitError(SynapseError):
class EventSizeError(SynapseError):
"""An error raised when an event is too big."""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.TOO_LARGE
super().__init__(413, *args, **kwargs)
class EventStreamError(SynapseError):
"""An error raised when there a problem with the event stream."""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.BAD_PAGINATION
super().__init__(*args, **kwargs)
def __init__(self, msg: str):
super().__init__(413, msg, Codes.TOO_LARGE)
class LoginError(SynapseError):
"""An error raised when there was a problem logging in."""
pass
class StoreError(SynapseError):
"""An error raised when there was a problem storing some data."""
pass
class InvalidCaptchaError(SynapseError):
def __init__(
@ -395,7 +374,7 @@ class InvalidCaptchaError(SynapseError):
super().__init__(code, msg, errcode)
self.error_url = error_url
def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, error_url=self.error_url)
@ -412,7 +391,7 @@ class LimitExceededError(SynapseError):
super().__init__(code, msg, errcode)
self.retry_after_ms = retry_after_ms
def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)
@ -443,10 +422,8 @@ class UnsupportedRoomVersionError(SynapseError):
class ThreepidValidationError(SynapseError):
"""An error raised when there was a problem authorising an event."""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.FORBIDDEN
super().__init__(*args, **kwargs)
def __init__(self, msg: str, errcode: str = Codes.FORBIDDEN):
super().__init__(400, msg, errcode)
class IncompatibleRoomVersionError(SynapseError):
@ -466,7 +443,7 @@ class IncompatibleRoomVersionError(SynapseError):
self._room_version = room_version
def error_dict(self):
def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, room_version=self._room_version)
@ -494,7 +471,7 @@ class RequestSendFailed(RuntimeError):
errors (like programming errors).
"""
def __init__(self, inner_exception, can_retry):
def __init__(self, inner_exception: BaseException, can_retry: bool):
super().__init__(
"Failed to send request: %s: %s"
% (type(inner_exception).__name__, inner_exception)
@ -503,7 +480,7 @@ class RequestSendFailed(RuntimeError):
self.can_retry = can_retry
def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs):
def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict":
"""Utility method for constructing an error response for client-server
interactions.
@ -551,7 +528,7 @@ class FederationError(RuntimeError):
msg = "%s %s: %s" % (level, code, reason)
super().__init__(msg)
def get_dict(self):
def get_dict(self) -> "JsonDict":
return {
"level": self.level,
"code": self.code,
@ -580,7 +557,7 @@ class HttpResponseException(CodeMessageException):
super().__init__(code, msg)
self.response = response
def to_synapse_error(self):
def to_synapse_error(self) -> SynapseError:
"""Make a SynapseError based on an HTTPResponseException
This is useful when a proxied request has failed, and we need to

View file

@ -231,24 +231,24 @@ class FilterCollection:
def include_redundant_members(self) -> bool:
return self._room_state_filter.include_redundant_members()
def filter_presence(self, events):
def filter_presence(
self, events: Iterable[UserPresenceState]
) -> List[UserPresenceState]:
return self._presence_filter.filter(events)
def filter_account_data(self, events):
def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._account_data.filter(events)
def filter_room_state(self, events):
def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
return self._room_state_filter.filter(self._room_filter.filter(events))
def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]:
return self._room_timeline_filter.filter(self._room_filter.filter(events))
def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
def filter_room_account_data(
self, events: Iterable[FilterEvent]
) -> List[FilterEvent]:
def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._room_account_data.filter(self._room_filter.filter(events))
def blocks_all_presence(self) -> bool:
@ -309,7 +309,7 @@ class Filter:
# except for presence which actually gets passed around as its own
# namedtuple type.
if isinstance(event, UserPresenceState):
sender = event.user_id
sender: Optional[str] = event.user_id
room_id = None
ev_type = "m.presence"
contains_url = False

View file

@ -12,49 +12,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from typing import Any, Optional
import attr
from synapse.api.constants import PresenceState
from synapse.types import JsonDict
class UserPresenceState(
namedtuple(
"UserPresenceState",
(
"user_id",
"state",
"last_active_ts",
"last_federation_update_ts",
"last_user_sync_ts",
"status_msg",
"currently_active",
),
)
):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class UserPresenceState:
"""Represents the current presence state of the user.
user_id (str)
last_active (int): Time in msec that the user last interacted with server.
last_federation_update (int): Time in msec since either a) we sent a presence
user_id
last_active: Time in msec that the user last interacted with server.
last_federation_update: Time in msec since either a) we sent a presence
update to other servers or b) we received a presence update, depending
on if is a local user or not.
last_user_sync (int): Time in msec that the user last *completed* a sync
last_user_sync: Time in msec that the user last *completed* a sync
(or event stream).
status_msg (str): User set status message.
status_msg: User set status message.
"""
def as_dict(self):
return dict(self._asdict())
user_id: str
state: str
last_active_ts: int
last_federation_update_ts: int
last_user_sync_ts: int
status_msg: Optional[str]
currently_active: bool
def as_dict(self) -> JsonDict:
return attr.asdict(self)
@staticmethod
def from_dict(d):
def from_dict(d: JsonDict) -> "UserPresenceState":
return UserPresenceState(**d)
def copy_and_replace(self, **kwargs):
return self._replace(**kwargs)
def copy_and_replace(self, **kwargs: Any) -> "UserPresenceState":
return attr.evolve(self, **kwargs)
@classmethod
def default(cls, user_id):
def default(cls, user_id: str) -> "UserPresenceState":
"""Returns a default presence state."""
return cls(
user_id=user_id,

View file

@ -161,7 +161,7 @@ class Ratelimiter:
return allowed, time_allowed
def _prune_message_counts(self, time_now_s: float):
def _prune_message_counts(self, time_now_s: float) -> None:
"""Remove message count entries that have not exceeded their defined
rate_hz limit
@ -190,7 +190,7 @@ class Ratelimiter:
update: bool = True,
n_actions: int = 1,
_time_now_s: Optional[float] = None,
):
) -> None:
"""Checks if an action can be performed. If not, raises a LimitExceededError
Checks if the user has ratelimiting disabled in the database by looking

View file

@ -19,6 +19,7 @@ from hashlib import sha256
from urllib.parse import urlencode
from synapse.config import ConfigError
from synapse.config.homeserver import HomeServerConfig
SYNAPSE_CLIENT_API_PREFIX = "/_synapse/client"
CLIENT_API_PREFIX = "/_matrix/client"
@ -34,11 +35,7 @@ LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
class ConsentURIBuilder:
def __init__(self, hs_config):
"""
Args:
hs_config (synapse.config.homeserver.HomeServerConfig):
"""
def __init__(self, hs_config: HomeServerConfig):
if hs_config.key.form_secret is None:
raise ConfigError("form_secret not set in config")
if hs_config.server.public_baseurl is None:
@ -47,15 +44,15 @@ class ConsentURIBuilder:
self._hmac_secret = hs_config.key.form_secret.encode("utf-8")
self._public_baseurl = hs_config.server.public_baseurl
def build_user_consent_uri(self, user_id):
def build_user_consent_uri(self, user_id: str) -> str:
"""Build a URI which we can give to the user to do their privacy
policy consent
Args:
user_id (str): mxid or username of user
user_id: mxid or username of user
Returns
(str) the URI where the user can do consent
The URI where the user can do consent
"""
mac = hmac.new(
key=self._hmac_secret, msg=user_id.encode("ascii"), digestmod=sha256

View file

@ -1489,7 +1489,7 @@ def format_user_presence_state(
The "user_id" is optional so that this function can be used to format presence
updates for client /sync responses and for federation /send requests.
"""
content = {"presence": state.state}
content: JsonDict = {"presence": state.state}
if include_user_id:
content["user_id"] = state.user_id
if state.last_active_ts:

View file

@ -2237,7 +2237,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
# accident.
row = {"client_secret": None, "validated_at": None}
else:
raise ThreepidValidationError(400, "Unknown session_id")
raise ThreepidValidationError("Unknown session_id")
retrieved_client_secret = row["client_secret"]
validated_at = row["validated_at"]
@ -2252,14 +2252,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
if not row:
raise ThreepidValidationError(
400, "Validation token not found or has expired"
"Validation token not found or has expired"
)
expires = row["expires"]
next_link = row["next_link"]
if retrieved_client_secret != client_secret:
raise ThreepidValidationError(
400, "This client_secret does not match the provided session_id"
"This client_secret does not match the provided session_id"
)
# If the session is already validated, no need to revalidate
@ -2268,7 +2268,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
if expires <= current_ts:
raise ThreepidValidationError(
400, "This token has expired. Please request a new one"
"This token has expired. Please request a new one"
)
# Looks good. Validate the session