0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-25 05:58:21 +02:00

Add type hints to profile and base handlers. (#8609)

This commit is contained in:
Patrick Cloke 2020-10-21 06:44:31 -04:00 committed by GitHub
parent 9e0f22874f
commit de5cafe980
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 72 additions and 41 deletions

1
changelog.d/8609.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to profile and base handler.

View file

@ -15,8 +15,9 @@ files =
synapse/events/builder.py, synapse/events/builder.py,
synapse/events/spamcheck.py, synapse/events/spamcheck.py,
synapse/federation, synapse/federation,
synapse/handlers/appservice.py, synapse/handlers/_base.py,
synapse/handlers/account_data.py, synapse/handlers/account_data.py,
synapse/handlers/appservice.py,
synapse/handlers/auth.py, synapse/handlers/auth.py,
synapse/handlers/cas_handler.py, synapse/handlers/cas_handler.py,
synapse/handlers/deactivate_account.py, synapse/handlers/deactivate_account.py,
@ -32,6 +33,7 @@ files =
synapse/handlers/pagination.py, synapse/handlers/pagination.py,
synapse/handlers/password_policy.py, synapse/handlers/password_policy.py,
synapse/handlers/presence.py, synapse/handlers/presence.py,
synapse/handlers/profile.py,
synapse/handlers/read_marker.py, synapse/handlers/read_marker.py,
synapse/handlers/room.py, synapse/handlers/room.py,
synapse/handlers/room_member.py, synapse/handlers/room_member.py,

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Optional
import synapse.state import synapse.state
import synapse.storage import synapse.storage
@ -22,6 +23,9 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.types import UserID from synapse.types import UserID
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,11 +34,7 @@ class BaseHandler:
Common base class for the event handlers. Common base class for the event handlers.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer):
"""
self.store = hs.get_datastore() # type: synapse.storage.DataStore self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@ -56,7 +56,7 @@ class BaseHandler:
clock=self.clock, clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second, rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count, burst_count=self.hs.config.rc_admin_redaction.burst_count,
) ) # type: Optional[Ratelimiter]
else: else:
self.admin_redaction_ratelimiter = None self.admin_redaction_ratelimiter = None
@ -127,15 +127,15 @@ class BaseHandler:
if guest_access != "can_join": if guest_access != "can_join":
if context: if context:
current_state_ids = await context.get_current_state_ids() current_state_ids = await context.get_current_state_ids()
current_state = await self.store.get_events( current_state_dict = await self.store.get_events(
list(current_state_ids.values()) list(current_state_ids.values())
) )
current_state = list(current_state_dict.values())
else: else:
current_state = await self.state_handler.get_current_state( current_state_map = await self.state_handler.get_current_state(
event.room_id event.room_id
) )
current_state = list(current_state_map.values())
current_state = list(current_state.values())
logger.info("maybe_kick_guest_users %r", current_state) logger.info("maybe_kick_guest_users %r", current_state)
await self.kick_guest_users(current_state) await self.kick_guest_users(current_state)

View file

@ -293,6 +293,10 @@ class InitialSyncHandler(BaseHandler):
user_id, room_id, pagin_config, membership, is_peeking user_id, room_id, pagin_config, membership, is_peeking
) )
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
# The member_event_id will always be available if membership is set
# to leave.
assert member_event_id
result = await self._room_initial_sync_parted( result = await self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking user_id, room_id, pagin_config, membership, member_event_id, is_peeking
) )
@ -315,7 +319,7 @@ class InitialSyncHandler(BaseHandler):
user_id: str, user_id: str,
room_id: str, room_id: str,
pagin_config: PaginationConfig, pagin_config: PaginationConfig,
membership: Membership, membership: str,
member_event_id: str, member_event_id: str,
is_peeking: bool, is_peeking: bool,
) -> JsonDict: ) -> JsonDict:
@ -367,7 +371,7 @@ class InitialSyncHandler(BaseHandler):
user_id: str, user_id: str,
room_id: str, room_id: str,
pagin_config: PaginationConfig, pagin_config: PaginationConfig,
membership: Membership, membership: str,
is_peeking: bool, is_peeking: bool,
) -> JsonDict: ) -> JsonDict:
current_state = await self.state.get_current_state(room_id=room_id) current_state = await self.state.get_current_state(room_id=room_id)

View file

@ -12,9 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import random import random
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
@ -25,10 +25,19 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID, create_requester, get_domain_from_id from synapse.types import (
JsonDict,
Requester,
UserID,
create_requester,
get_domain_from_id,
)
from ._base import BaseHandler from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MAX_DISPLAYNAME_LEN = 256 MAX_DISPLAYNAME_LEN = 256
@ -45,7 +54,7 @@ class ProfileHandler(BaseHandler):
PROFILE_UPDATE_MS = 60 * 1000 PROFILE_UPDATE_MS = 60 * 1000
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.federation = hs.get_federation_client() self.federation = hs.get_federation_client()
@ -60,7 +69,7 @@ class ProfileHandler(BaseHandler):
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
) )
async def get_profile(self, user_id): async def get_profile(self, user_id: str) -> JsonDict:
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
@ -91,7 +100,7 @@ class ProfileHandler(BaseHandler):
except HttpResponseException as e: except HttpResponseException as e:
raise e.to_synapse_error() raise e.to_synapse_error()
async def get_profile_from_cache(self, user_id): async def get_profile_from_cache(self, user_id: str) -> JsonDict:
"""Get the profile information from our local cache. If the user is """Get the profile information from our local cache. If the user is
ours then the profile information will always be corect. Otherwise, ours then the profile information will always be corect. Otherwise,
it may be out of date/missing. it may be out of date/missing.
@ -115,7 +124,7 @@ class ProfileHandler(BaseHandler):
profile = await self.store.get_from_remote_profile_cache(user_id) profile = await self.store.get_from_remote_profile_cache(user_id)
return profile or {} return profile or {}
async def get_displayname(self, target_user): async def get_displayname(self, target_user: UserID) -> str:
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
try: try:
displayname = await self.store.get_profile_displayname( displayname = await self.store.get_profile_displayname(
@ -143,15 +152,19 @@ class ProfileHandler(BaseHandler):
return result["displayname"] return result["displayname"]
async def set_displayname( async def set_displayname(
self, target_user, requester, new_displayname, by_admin=False self,
): target_user: UserID,
requester: Requester,
new_displayname: str,
by_admin: bool = False,
) -> None:
"""Set the displayname of a user """Set the displayname of a user
Args: Args:
target_user (UserID): the user whose displayname is to be changed. target_user: the user whose displayname is to be changed.
requester (Requester): The user attempting to make this change. requester: The user attempting to make this change.
new_displayname (str): The displayname to give this user. new_displayname: The displayname to give this user.
by_admin (bool): Whether this change was made by an administrator. by_admin: Whether this change was made by an administrator.
""" """
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver") raise SynapseError(400, "User is not hosted on this homeserver")
@ -176,8 +189,9 @@ class ProfileHandler(BaseHandler):
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
) )
displayname_to_set = new_displayname # type: Optional[str]
if new_displayname == "": if new_displayname == "":
new_displayname = None displayname_to_set = None
# If the admin changes the display name of a user, the requesting user cannot send # If the admin changes the display name of a user, the requesting user cannot send
# the join event to update the displayname in the rooms. # the join event to update the displayname in the rooms.
@ -185,7 +199,9 @@ class ProfileHandler(BaseHandler):
if by_admin: if by_admin:
requester = create_requester(target_user) requester = create_requester(target_user)
await self.store.set_profile_displayname(target_user.localpart, new_displayname) await self.store.set_profile_displayname(
target_user.localpart, displayname_to_set
)
if self.hs.config.user_directory_search_all_users: if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart) profile = await self.store.get_profileinfo(target_user.localpart)
@ -195,7 +211,7 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user) await self._update_join_states(requester, target_user)
async def get_avatar_url(self, target_user): async def get_avatar_url(self, target_user: UserID) -> str:
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
try: try:
avatar_url = await self.store.get_profile_avatar_url( avatar_url = await self.store.get_profile_avatar_url(
@ -222,15 +238,19 @@ class ProfileHandler(BaseHandler):
return result["avatar_url"] return result["avatar_url"]
async def set_avatar_url( async def set_avatar_url(
self, target_user, requester, new_avatar_url, by_admin=False self,
target_user: UserID,
requester: Requester,
new_avatar_url: str,
by_admin: bool = False,
): ):
"""Set a new avatar URL for a user. """Set a new avatar URL for a user.
Args: Args:
target_user (UserID): the user whose avatar URL is to be changed. target_user: the user whose avatar URL is to be changed.
requester (Requester): The user attempting to make this change. requester: The user attempting to make this change.
new_avatar_url (str): The avatar URL to give this user. new_avatar_url: The avatar URL to give this user.
by_admin (bool): Whether this change was made by an administrator. by_admin: Whether this change was made by an administrator.
""" """
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver") raise SynapseError(400, "User is not hosted on this homeserver")
@ -267,7 +287,7 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user) await self._update_join_states(requester, target_user)
async def on_profile_query(self, args): async def on_profile_query(self, args: JsonDict) -> JsonDict:
user = UserID.from_string(args["user_id"]) user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver") raise SynapseError(400, "User is not hosted on this homeserver")
@ -292,7 +312,9 @@ class ProfileHandler(BaseHandler):
return response return response
async def _update_join_states(self, requester, target_user): async def _update_join_states(
self, requester: Requester, target_user: UserID
) -> None:
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
return return
@ -323,15 +345,17 @@ class ProfileHandler(BaseHandler):
"Failed to update join event for room %s - %s", room_id, str(e) "Failed to update join event for room %s - %s", room_id, str(e)
) )
async def check_profile_query_allowed(self, target_user, requester=None): async def check_profile_query_allowed(
self, target_user: UserID, requester: Optional[UserID] = None
) -> None:
"""Checks whether a profile query is allowed. If the """Checks whether a profile query is allowed. If the
'require_auth_for_profile_requests' config flag is set to True and a 'require_auth_for_profile_requests' config flag is set to True and a
'requester' is provided, the query is only allowed if the two users 'requester' is provided, the query is only allowed if the two users
share a room. share a room.
Args: Args:
target_user (UserID): The owner of the queried profile. target_user: The owner of the queried profile.
requester (None|UserID): The user querying for the profile. requester: The user querying for the profile.
Raises: Raises:
SynapseError(403): The two users share no room, or ne user couldn't SynapseError(403): The two users share no room, or ne user couldn't

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, Optional from typing import Any, Dict, List, Optional
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
@ -72,7 +72,7 @@ class ProfileWorkerStore(SQLBaseStore):
) )
async def set_profile_displayname( async def set_profile_displayname(
self, user_localpart: str, new_displayname: str self, user_localpart: str, new_displayname: Optional[str]
) -> None: ) -> None:
await self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
table="profiles", table="profiles",
@ -144,7 +144,7 @@ class ProfileWorkerStore(SQLBaseStore):
async def get_remote_profile_cache_entries_that_expire( async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int self, last_checked: int
) -> Dict[str, str]: ) -> List[Dict[str, str]]:
"""Get all users who haven't been checked since `last_checked` """Get all users who haven't been checked since `last_checked`
""" """