mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 23:03:51 +01:00
Add type hints to profile and base handlers. (#8609)
This commit is contained in:
parent
9e0f22874f
commit
de5cafe980
6 changed files with 72 additions and 41 deletions
1
changelog.d/8609.misc
Normal file
1
changelog.d/8609.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add type hints to profile and base handler.
|
4
mypy.ini
4
mypy.ini
|
@ -15,8 +15,9 @@ files =
|
||||||
synapse/events/builder.py,
|
synapse/events/builder.py,
|
||||||
synapse/events/spamcheck.py,
|
synapse/events/spamcheck.py,
|
||||||
synapse/federation,
|
synapse/federation,
|
||||||
synapse/handlers/appservice.py,
|
synapse/handlers/_base.py,
|
||||||
synapse/handlers/account_data.py,
|
synapse/handlers/account_data.py,
|
||||||
|
synapse/handlers/appservice.py,
|
||||||
synapse/handlers/auth.py,
|
synapse/handlers/auth.py,
|
||||||
synapse/handlers/cas_handler.py,
|
synapse/handlers/cas_handler.py,
|
||||||
synapse/handlers/deactivate_account.py,
|
synapse/handlers/deactivate_account.py,
|
||||||
|
@ -32,6 +33,7 @@ files =
|
||||||
synapse/handlers/pagination.py,
|
synapse/handlers/pagination.py,
|
||||||
synapse/handlers/password_policy.py,
|
synapse/handlers/password_policy.py,
|
||||||
synapse/handlers/presence.py,
|
synapse/handlers/presence.py,
|
||||||
|
synapse/handlers/profile.py,
|
||||||
synapse/handlers/read_marker.py,
|
synapse/handlers/read_marker.py,
|
||||||
synapse/handlers/room.py,
|
synapse/handlers/room.py,
|
||||||
synapse/handlers/room_member.py,
|
synapse/handlers/room_member.py,
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import synapse.state
|
import synapse.state
|
||||||
import synapse.storage
|
import synapse.storage
|
||||||
|
@ -22,6 +23,9 @@ from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,11 +34,7 @@ class BaseHandler:
|
||||||
Common base class for the event handlers.
|
Common base class for the event handlers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hs (synapse.server.HomeServer):
|
|
||||||
"""
|
|
||||||
self.store = hs.get_datastore() # type: synapse.storage.DataStore
|
self.store = hs.get_datastore() # type: synapse.storage.DataStore
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
|
@ -56,7 +56,7 @@ class BaseHandler:
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=self.hs.config.rc_admin_redaction.per_second,
|
rate_hz=self.hs.config.rc_admin_redaction.per_second,
|
||||||
burst_count=self.hs.config.rc_admin_redaction.burst_count,
|
burst_count=self.hs.config.rc_admin_redaction.burst_count,
|
||||||
)
|
) # type: Optional[Ratelimiter]
|
||||||
else:
|
else:
|
||||||
self.admin_redaction_ratelimiter = None
|
self.admin_redaction_ratelimiter = None
|
||||||
|
|
||||||
|
@ -127,15 +127,15 @@ class BaseHandler:
|
||||||
if guest_access != "can_join":
|
if guest_access != "can_join":
|
||||||
if context:
|
if context:
|
||||||
current_state_ids = await context.get_current_state_ids()
|
current_state_ids = await context.get_current_state_ids()
|
||||||
current_state = await self.store.get_events(
|
current_state_dict = await self.store.get_events(
|
||||||
list(current_state_ids.values())
|
list(current_state_ids.values())
|
||||||
)
|
)
|
||||||
|
current_state = list(current_state_dict.values())
|
||||||
else:
|
else:
|
||||||
current_state = await self.state_handler.get_current_state(
|
current_state_map = await self.state_handler.get_current_state(
|
||||||
event.room_id
|
event.room_id
|
||||||
)
|
)
|
||||||
|
current_state = list(current_state_map.values())
|
||||||
current_state = list(current_state.values())
|
|
||||||
|
|
||||||
logger.info("maybe_kick_guest_users %r", current_state)
|
logger.info("maybe_kick_guest_users %r", current_state)
|
||||||
await self.kick_guest_users(current_state)
|
await self.kick_guest_users(current_state)
|
||||||
|
|
|
@ -293,6 +293,10 @@ class InitialSyncHandler(BaseHandler):
|
||||||
user_id, room_id, pagin_config, membership, is_peeking
|
user_id, room_id, pagin_config, membership, is_peeking
|
||||||
)
|
)
|
||||||
elif membership == Membership.LEAVE:
|
elif membership == Membership.LEAVE:
|
||||||
|
# The member_event_id will always be available if membership is set
|
||||||
|
# to leave.
|
||||||
|
assert member_event_id
|
||||||
|
|
||||||
result = await self._room_initial_sync_parted(
|
result = await self._room_initial_sync_parted(
|
||||||
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
|
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
|
||||||
)
|
)
|
||||||
|
@ -315,7 +319,7 @@ class InitialSyncHandler(BaseHandler):
|
||||||
user_id: str,
|
user_id: str,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
pagin_config: PaginationConfig,
|
pagin_config: PaginationConfig,
|
||||||
membership: Membership,
|
membership: str,
|
||||||
member_event_id: str,
|
member_event_id: str,
|
||||||
is_peeking: bool,
|
is_peeking: bool,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
|
@ -367,7 +371,7 @@ class InitialSyncHandler(BaseHandler):
|
||||||
user_id: str,
|
user_id: str,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
pagin_config: PaginationConfig,
|
pagin_config: PaginationConfig,
|
||||||
membership: Membership,
|
membership: str,
|
||||||
is_peeking: bool,
|
is_peeking: bool,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
current_state = await self.state.get_current_state(room_id=room_id)
|
current_state = await self.state.get_current_state(room_id=room_id)
|
||||||
|
|
|
@ -12,9 +12,9 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError,
|
AuthError,
|
||||||
|
@ -25,10 +25,19 @@ from synapse.api.errors import (
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.types import UserID, create_requester, get_domain_from_id
|
from synapse.types import (
|
||||||
|
JsonDict,
|
||||||
|
Requester,
|
||||||
|
UserID,
|
||||||
|
create_requester,
|
||||||
|
get_domain_from_id,
|
||||||
|
)
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
MAX_DISPLAYNAME_LEN = 256
|
MAX_DISPLAYNAME_LEN = 256
|
||||||
|
@ -45,7 +54,7 @@ class ProfileHandler(BaseHandler):
|
||||||
PROFILE_UPDATE_MS = 60 * 1000
|
PROFILE_UPDATE_MS = 60 * 1000
|
||||||
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
|
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
self.federation = hs.get_federation_client()
|
self.federation = hs.get_federation_client()
|
||||||
|
@ -60,7 +69,7 @@ class ProfileHandler(BaseHandler):
|
||||||
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
|
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_profile(self, user_id):
|
async def get_profile(self, user_id: str) -> JsonDict:
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
|
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
|
@ -91,7 +100,7 @@ class ProfileHandler(BaseHandler):
|
||||||
except HttpResponseException as e:
|
except HttpResponseException as e:
|
||||||
raise e.to_synapse_error()
|
raise e.to_synapse_error()
|
||||||
|
|
||||||
async def get_profile_from_cache(self, user_id):
|
async def get_profile_from_cache(self, user_id: str) -> JsonDict:
|
||||||
"""Get the profile information from our local cache. If the user is
|
"""Get the profile information from our local cache. If the user is
|
||||||
ours then the profile information will always be corect. Otherwise,
|
ours then the profile information will always be corect. Otherwise,
|
||||||
it may be out of date/missing.
|
it may be out of date/missing.
|
||||||
|
@ -115,7 +124,7 @@ class ProfileHandler(BaseHandler):
|
||||||
profile = await self.store.get_from_remote_profile_cache(user_id)
|
profile = await self.store.get_from_remote_profile_cache(user_id)
|
||||||
return profile or {}
|
return profile or {}
|
||||||
|
|
||||||
async def get_displayname(self, target_user):
|
async def get_displayname(self, target_user: UserID) -> str:
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
try:
|
try:
|
||||||
displayname = await self.store.get_profile_displayname(
|
displayname = await self.store.get_profile_displayname(
|
||||||
|
@ -143,15 +152,19 @@ class ProfileHandler(BaseHandler):
|
||||||
return result["displayname"]
|
return result["displayname"]
|
||||||
|
|
||||||
async def set_displayname(
|
async def set_displayname(
|
||||||
self, target_user, requester, new_displayname, by_admin=False
|
self,
|
||||||
):
|
target_user: UserID,
|
||||||
|
requester: Requester,
|
||||||
|
new_displayname: str,
|
||||||
|
by_admin: bool = False,
|
||||||
|
) -> None:
|
||||||
"""Set the displayname of a user
|
"""Set the displayname of a user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_user (UserID): the user whose displayname is to be changed.
|
target_user: the user whose displayname is to be changed.
|
||||||
requester (Requester): The user attempting to make this change.
|
requester: The user attempting to make this change.
|
||||||
new_displayname (str): The displayname to give this user.
|
new_displayname: The displayname to give this user.
|
||||||
by_admin (bool): Whether this change was made by an administrator.
|
by_admin: Whether this change was made by an administrator.
|
||||||
"""
|
"""
|
||||||
if not self.hs.is_mine(target_user):
|
if not self.hs.is_mine(target_user):
|
||||||
raise SynapseError(400, "User is not hosted on this homeserver")
|
raise SynapseError(400, "User is not hosted on this homeserver")
|
||||||
|
@ -176,8 +189,9 @@ class ProfileHandler(BaseHandler):
|
||||||
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
|
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
displayname_to_set = new_displayname # type: Optional[str]
|
||||||
if new_displayname == "":
|
if new_displayname == "":
|
||||||
new_displayname = None
|
displayname_to_set = None
|
||||||
|
|
||||||
# If the admin changes the display name of a user, the requesting user cannot send
|
# If the admin changes the display name of a user, the requesting user cannot send
|
||||||
# the join event to update the displayname in the rooms.
|
# the join event to update the displayname in the rooms.
|
||||||
|
@ -185,7 +199,9 @@ class ProfileHandler(BaseHandler):
|
||||||
if by_admin:
|
if by_admin:
|
||||||
requester = create_requester(target_user)
|
requester = create_requester(target_user)
|
||||||
|
|
||||||
await self.store.set_profile_displayname(target_user.localpart, new_displayname)
|
await self.store.set_profile_displayname(
|
||||||
|
target_user.localpart, displayname_to_set
|
||||||
|
)
|
||||||
|
|
||||||
if self.hs.config.user_directory_search_all_users:
|
if self.hs.config.user_directory_search_all_users:
|
||||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||||
|
@ -195,7 +211,7 @@ class ProfileHandler(BaseHandler):
|
||||||
|
|
||||||
await self._update_join_states(requester, target_user)
|
await self._update_join_states(requester, target_user)
|
||||||
|
|
||||||
async def get_avatar_url(self, target_user):
|
async def get_avatar_url(self, target_user: UserID) -> str:
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
try:
|
try:
|
||||||
avatar_url = await self.store.get_profile_avatar_url(
|
avatar_url = await self.store.get_profile_avatar_url(
|
||||||
|
@ -222,15 +238,19 @@ class ProfileHandler(BaseHandler):
|
||||||
return result["avatar_url"]
|
return result["avatar_url"]
|
||||||
|
|
||||||
async def set_avatar_url(
|
async def set_avatar_url(
|
||||||
self, target_user, requester, new_avatar_url, by_admin=False
|
self,
|
||||||
|
target_user: UserID,
|
||||||
|
requester: Requester,
|
||||||
|
new_avatar_url: str,
|
||||||
|
by_admin: bool = False,
|
||||||
):
|
):
|
||||||
"""Set a new avatar URL for a user.
|
"""Set a new avatar URL for a user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_user (UserID): the user whose avatar URL is to be changed.
|
target_user: the user whose avatar URL is to be changed.
|
||||||
requester (Requester): The user attempting to make this change.
|
requester: The user attempting to make this change.
|
||||||
new_avatar_url (str): The avatar URL to give this user.
|
new_avatar_url: The avatar URL to give this user.
|
||||||
by_admin (bool): Whether this change was made by an administrator.
|
by_admin: Whether this change was made by an administrator.
|
||||||
"""
|
"""
|
||||||
if not self.hs.is_mine(target_user):
|
if not self.hs.is_mine(target_user):
|
||||||
raise SynapseError(400, "User is not hosted on this homeserver")
|
raise SynapseError(400, "User is not hosted on this homeserver")
|
||||||
|
@ -267,7 +287,7 @@ class ProfileHandler(BaseHandler):
|
||||||
|
|
||||||
await self._update_join_states(requester, target_user)
|
await self._update_join_states(requester, target_user)
|
||||||
|
|
||||||
async def on_profile_query(self, args):
|
async def on_profile_query(self, args: JsonDict) -> JsonDict:
|
||||||
user = UserID.from_string(args["user_id"])
|
user = UserID.from_string(args["user_id"])
|
||||||
if not self.hs.is_mine(user):
|
if not self.hs.is_mine(user):
|
||||||
raise SynapseError(400, "User is not hosted on this homeserver")
|
raise SynapseError(400, "User is not hosted on this homeserver")
|
||||||
|
@ -292,7 +312,9 @@ class ProfileHandler(BaseHandler):
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def _update_join_states(self, requester, target_user):
|
async def _update_join_states(
|
||||||
|
self, requester: Requester, target_user: UserID
|
||||||
|
) -> None:
|
||||||
if not self.hs.is_mine(target_user):
|
if not self.hs.is_mine(target_user):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -323,15 +345,17 @@ class ProfileHandler(BaseHandler):
|
||||||
"Failed to update join event for room %s - %s", room_id, str(e)
|
"Failed to update join event for room %s - %s", room_id, str(e)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def check_profile_query_allowed(self, target_user, requester=None):
|
async def check_profile_query_allowed(
|
||||||
|
self, target_user: UserID, requester: Optional[UserID] = None
|
||||||
|
) -> None:
|
||||||
"""Checks whether a profile query is allowed. If the
|
"""Checks whether a profile query is allowed. If the
|
||||||
'require_auth_for_profile_requests' config flag is set to True and a
|
'require_auth_for_profile_requests' config flag is set to True and a
|
||||||
'requester' is provided, the query is only allowed if the two users
|
'requester' is provided, the query is only allowed if the two users
|
||||||
share a room.
|
share a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_user (UserID): The owner of the queried profile.
|
target_user: The owner of the queried profile.
|
||||||
requester (None|UserID): The user querying for the profile.
|
requester: The user querying for the profile.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError(403): The two users share no room, or ne user couldn't
|
SynapseError(403): The two users share no room, or ne user couldn't
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
@ -72,7 +72,7 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def set_profile_displayname(
|
async def set_profile_displayname(
|
||||||
self, user_localpart: str, new_displayname: str
|
self, user_localpart: str, new_displayname: Optional[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
await self.db_pool.simple_update_one(
|
await self.db_pool.simple_update_one(
|
||||||
table="profiles",
|
table="profiles",
|
||||||
|
@ -144,7 +144,7 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
async def get_remote_profile_cache_entries_that_expire(
|
async def get_remote_profile_cache_entries_that_expire(
|
||||||
self, last_checked: int
|
self, last_checked: int
|
||||||
) -> Dict[str, str]:
|
) -> List[Dict[str, str]]:
|
||||||
"""Get all users who haven't been checked since `last_checked`
|
"""Get all users who haven't been checked since `last_checked`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue