Add a type hint for get_device_handler() and fix incorrect types. (#14055)

This was the last untyped handler from the HomeServer object. Since
it was being treated as Any (and thus unchecked) it was being used
incorrectly in a few places.
This commit is contained in:
Patrick Cloke 2022-11-22 14:08:04 -05:00 committed by GitHub
parent 9b4cb1e2ed
commit 6d47b7e325
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 185 additions and 77 deletions

1
changelog.d/14055.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to `HomeServer`.

View file

@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import Codes, Requester, UserID, create_requester from synapse.types import Codes, Requester, UserID, create_requester
@ -76,6 +77,9 @@ class DeactivateAccountHandler:
True if identity server supports removing threepids, otherwise False. True if identity server supports removing threepids, otherwise False.
""" """
# This can only be called on the main process.
assert isinstance(self._device_handler, DeviceHandler)
# Check if this user can be deactivated # Check if this user can be deactivated
if not await self._third_party_rules.check_can_deactivate_user( if not await self._third_party_rules.check_can_deactivate_user(
user_id, by_admin user_id, by_admin

View file

@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
class DeviceWorkerHandler: class DeviceWorkerHandler:
device_list_updater: "DeviceListWorkerUpdater"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.hs = hs self.hs = hs
@ -76,6 +78,8 @@ class DeviceWorkerHandler:
self.server_name = hs.hostname self.server_name = hs.hostname
self._msc3852_enabled = hs.config.experimental.msc3852_enabled self._msc3852_enabled = hs.config.experimental.msc3852_enabled
self.device_list_updater = DeviceListWorkerUpdater(hs)
@trace @trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
""" """
@ -99,6 +103,19 @@ class DeviceWorkerHandler:
log_kv(device_map) log_kv(device_map)
return devices return devices
async def get_dehydrated_device(
self, user_id: str
) -> Optional[Tuple[str, JsonDict]]:
"""Retrieve the information for a dehydrated device.
Args:
user_id: the user whose dehydrated device we are looking for
Returns:
a tuple whose first item is the device ID, and the second item is
the dehydrated device information
"""
return await self.store.get_dehydrated_device(user_id)
@trace @trace
async def get_device(self, user_id: str, device_id: str) -> JsonDict: async def get_device(self, user_id: str, device_id: str) -> JsonDict:
"""Retrieve the given device """Retrieve the given device
@ -127,7 +144,7 @@ class DeviceWorkerHandler:
@cancellable @cancellable
async def get_device_changes_in_shared_rooms( async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: Collection[str], from_token: StreamToken self, user_id: str, room_ids: Collection[str], from_token: StreamToken
) -> Collection[str]: ) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with """Get the set of users whose devices have changed who share a room with
the given user. the given user.
""" """
@ -320,6 +337,8 @@ class DeviceWorkerHandler:
class DeviceHandler(DeviceWorkerHandler): class DeviceHandler(DeviceWorkerHandler):
device_list_updater: "DeviceListUpdater"
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler):
await self.delete_devices(user_id, [old_device_id]) await self.delete_devices(user_id, [old_device_id])
return device_id return device_id
async def get_dehydrated_device(
self, user_id: str
) -> Optional[Tuple[str, JsonDict]]:
"""Retrieve the information for a dehydrated device.
Args:
user_id: the user whose dehydrated device we are looking for
Returns:
a tuple whose first item is the device ID, and the second item is
the dehydrated device information
"""
return await self.store.get_dehydrated_device(user_id)
async def rehydrate_device( async def rehydrate_device(
self, user_id: str, access_token: str, device_id: str self, user_id: str, access_token: str, device_id: str
) -> dict: ) -> dict:
@ -882,7 +888,36 @@ def _update_device_from_client_ips(
) )
class DeviceListUpdater: class DeviceListWorkerUpdater:
"Handles incoming device list updates from federation and contacts the main process over replication"
def __init__(self, hs: "HomeServer"):
from synapse.replication.http.devices import (
ReplicationUserDevicesResyncRestServlet,
)
self._user_device_resync_client = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
) -> Optional[JsonDict]:
"""Fetches all devices for a user and updates the device cache with them.
Args:
user_id: The user's id whose device_list will be updated.
mark_failed_as_stale: Whether to mark the user's device list as stale
if the attempt to resync failed.
Returns:
A dict with device info as under the "devices" in the result of this
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
"""
return await self._user_device_resync_client(user_id=user_id)
class DeviceListUpdater(DeviceListWorkerUpdater):
"Handles incoming device list updates from federation and updates the DB" "Handles incoming device list updates from federation and updates the DB"
def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):

View file

@ -27,9 +27,9 @@ from twisted.internet import defer
from synapse.api.constants import EduTypes from synapse.api.constants import EduTypes
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import ( from synapse.types import (
JsonDict, JsonDict,
UserID, UserID,
@ -56,27 +56,23 @@ class E2eKeysHandler:
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._edu_updater = SigningKeyEduUpdater(hs, self)
federation_registry = hs.get_federation_registry() federation_registry = hs.get_federation_registry()
self._is_master = hs.config.worker.worker_app is None is_master = hs.config.worker.worker_app is None
if not self._is_master: if is_master:
self._user_device_resync_client = ( edu_updater = SigningKeyEduUpdater(hs)
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
else:
# Only register this edu handler on master as it requires writing # Only register this edu handler on master as it requires writing
# device updates to the db # device updates to the db
federation_registry.register_edu_handler( federation_registry.register_edu_handler(
EduTypes.SIGNING_KEY_UPDATE, EduTypes.SIGNING_KEY_UPDATE,
self._edu_updater.incoming_signing_key_update, edu_updater.incoming_signing_key_update,
) )
# also handle the unstable version # also handle the unstable version
# FIXME: remove this when enough servers have upgraded # FIXME: remove this when enough servers have upgraded
federation_registry.register_edu_handler( federation_registry.register_edu_handler(
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE, EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
self._edu_updater.incoming_signing_key_update, edu_updater.incoming_signing_key_update,
) )
# doesn't really work as part of the generic query API, because the # doesn't really work as part of the generic query API, because the
@ -319,14 +315,13 @@ class E2eKeysHandler:
# probably be tracking their device lists. However, we haven't # probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now. # done an initial sync on the device list so we do it now.
try: try:
if self._is_master: resync_results = (
resync_results = await self.device_handler.device_list_updater.user_device_resync( await self.device_handler.device_list_updater.user_device_resync(
user_id user_id
) )
else: )
resync_results = await self._user_device_resync_client( if resync_results is None:
user_id=user_id raise ValueError("Device resync failed")
)
# Add the device keys to the results. # Add the device keys to the results.
user_devices = resync_results["devices"] user_devices = resync_results["devices"]
@ -605,6 +600,8 @@ class E2eKeysHandler:
async def upload_keys_for_user( async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict: ) -> JsonDict:
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
@ -732,6 +729,8 @@ class E2eKeysHandler:
user_id: the user uploading the keys user_id: the user uploading the keys
keys: the signing keys keys: the signing keys
""" """
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
# if a master key is uploaded, then check it. Otherwise, load the # if a master key is uploaded, then check it. Otherwise, load the
# stored master key, to check signatures on other keys # stored master key, to check signatures on other keys
@ -823,6 +822,9 @@ class E2eKeysHandler:
Raises: Raises:
SynapseError: if the signatures dict is not valid. SynapseError: if the signatures dict is not valid.
""" """
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
failures = {} failures = {}
# signatures to be stored. Each item will be a SignatureListItem # signatures to be stored. Each item will be a SignatureListItem
@ -1200,6 +1202,9 @@ class E2eKeysHandler:
A tuple of the retrieved key content, the key's ID and the matching VerifyKey. A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
If the key cannot be retrieved, all values in the tuple will instead be None. If the key cannot be retrieved, all values in the tuple will instead be None.
""" """
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
try: try:
remote_result = await self.federation.query_user_devices( remote_result = await self.federation.query_user_devices(
user.domain, user.to_string() user.domain, user.to_string()
@ -1396,11 +1401,14 @@ class SignatureListItem:
class SigningKeyEduUpdater: class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB""" """Handles incoming signing key updates from federation and updates the DB"""
def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.federation = hs.get_federation_client() self.federation = hs.get_federation_client()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.e2e_keys_handler = e2e_keys_handler
device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
self._device_handler = device_handler
self._remote_edu_linearizer = Linearizer(name="remote_signing_key") self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
@ -1445,9 +1453,6 @@ class SigningKeyEduUpdater:
user_id: the user whose updates we are processing user_id: the user whose updates we are processing
""" """
device_handler = self.e2e_keys_handler.device_handler
device_list_updater = device_handler.device_list_updater
async with self._remote_edu_linearizer.queue(user_id): async with self._remote_edu_linearizer.queue(user_id):
pending_updates = self._pending_updates.pop(user_id, []) pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates: if not pending_updates:
@ -1459,13 +1464,11 @@ class SigningKeyEduUpdater:
logger.info("pending updates: %r", pending_updates) logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates: for master_key, self_signing_key in pending_updates:
new_device_ids = ( new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update(
await device_list_updater.process_cross_signing_key_update( user_id,
user_id, master_key,
master_key, self_signing_key,
self_signing_key,
)
) )
device_ids = device_ids + new_device_ids device_ids = device_ids + new_device_ids
await device_handler.notify_device_update(user_id, device_ids) await self._device_handler.notify_device_update(user_id, device_ids)

View file

@ -38,6 +38,7 @@ from synapse.api.errors import (
) )
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved from synapse.config.server import is_threepid_reserved
from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import assert_params_in_dict from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.replication.http.register import ( from synapse.replication.http.register import (
@ -841,6 +842,9 @@ class RegistrationHandler:
refresh_token = None refresh_token = None
refresh_token_id = None refresh_token_id = None
# This can only run on the main process.
assert isinstance(self.device_handler, DeviceHandler)
registered_device_id = await self.device_handler.check_device_registered( registered_device_id = await self.device_handler.check_device_registered(
user_id, user_id,
device_id, device_id,

View file

@ -15,6 +15,7 @@ import logging
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.types import Requester from synapse.types import Requester
if TYPE_CHECKING: if TYPE_CHECKING:
@ -29,7 +30,10 @@ class SetPasswordHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler() # This can only be instantiated on the main process.
device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
self._device_handler = device_handler
async def set_password( async def set_password(
self, self,

View file

@ -37,6 +37,7 @@ from twisted.web.server import Request
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement from synapse.config.sso import SsoAttributeRequirement
from synapse.handlers.device import DeviceHandler
from synapse.handlers.register import init_counters_for_auth_provider from synapse.handlers.register import init_counters_for_auth_provider
from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent from synapse.http import get_request_user_agent
@ -1035,6 +1036,8 @@ class SsoHandler:
) -> None: ) -> None:
"""Revoke any devices and in-flight logins tied to a provider session. """Revoke any devices and in-flight logins tied to a provider session.
Can only be called from the main process.
Args: Args:
auth_provider_id: A unique identifier for this SSO provider, e.g. auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml". "oidc" or "saml".
@ -1042,6 +1045,12 @@ class SsoHandler:
expected_user_id: The user we're expecting to logout. If set, it will ignore expected_user_id: The user we're expecting to logout. If set, it will ignore
sessions belonging to other users and log an error. sessions belonging to other users and log an error.
""" """
# It is expected that this is the main process.
assert isinstance(
self._device_handler, DeviceHandler
), "revoking SSO sessions can only be called on the main process"
# Invalidate any running user-mapping sessions # Invalidate any running user-mapping sessions
to_delete = [] to_delete = []
for session_id, session in self._username_mapping_sessions.items(): for session_id, session in self._username_mapping_sessions.items():

View file

@ -86,6 +86,7 @@ from synapse.handlers.auth import (
ON_LOGGED_OUT_CALLBACK, ON_LOGGED_OUT_CALLBACK,
AuthHandler, AuthHandler,
) )
from synapse.handlers.device import DeviceHandler
from synapse.handlers.push_rules import RuleSpec, check_actions from synapse.handlers.push_rules import RuleSpec, check_actions
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.http.server import ( from synapse.http.server import (
@ -207,6 +208,7 @@ class ModuleApi:
self._registration_handler = hs.get_registration_handler() self._registration_handler = hs.get_registration_handler()
self._send_email_handler = hs.get_send_email_handler() self._send_email_handler = hs.get_send_email_handler()
self._push_rules_handler = hs.get_push_rules_handler() self._push_rules_handler = hs.get_push_rules_handler()
self._device_handler = hs.get_device_handler()
self.custom_template_dir = hs.config.server.custom_template_directory self.custom_template_dir = hs.config.server.custom_template_directory
try: try:
@ -784,6 +786,8 @@ class ModuleApi:
) -> Generator["defer.Deferred[Any]", Any, None]: ) -> Generator["defer.Deferred[Any]", Any, None]:
"""Invalidate an access token for a user """Invalidate an access token for a user
Can only be called from the main process.
Added in Synapse v0.25.0. Added in Synapse v0.25.0.
Args: Args:
@ -796,6 +800,10 @@ class ModuleApi:
Raises: Raises:
synapse.api.errors.AuthError: the access token is invalid synapse.api.errors.AuthError: the access token is invalid
""" """
assert isinstance(
self._device_handler, DeviceHandler
), "invalidate_access_token can only be called on the main process"
# see if the access token corresponds to a device # see if the access token corresponds to a device
user_info = yield defer.ensureDeferred( user_info = yield defer.ensureDeferred(
self._auth.get_user_by_access_token(access_token) self._auth.get_user_by_access_token(access_token)
@ -805,7 +813,7 @@ class ModuleApi:
if device_id: if device_id:
# delete the device, which will also delete its access tokens # delete the device, which will also delete its access tokens
yield defer.ensureDeferred( yield defer.ensureDeferred(
self._hs.get_device_handler().delete_devices(user_id, [device_id]) self._device_handler.delete_devices(user_id, [device_id])
) )
else: else:
# no associated device. Just delete the access token. # no associated device. Just delete the access token.

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Optional, Tuple
from twisted.web.server import Request from twisted.web.server import Request
@ -63,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.device_list_updater = hs.get_device_handler().device_list_updater from synapse.handlers.device import DeviceHandler
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_list_updater = handler.device_list_updater
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -73,7 +78,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str self, request: Request, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, Optional[JsonDict]]:
user_devices = await self.device_list_updater.user_device_resync(user_id) user_devices = await self.device_list_updater.user_device_resync(user_id)
return 200, user_devices return 200, user_devices

View file

@ -238,6 +238,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
""" """
Register all the admin servlets. Register all the admin servlets.
""" """
# Admin servlets aren't registered on workers.
if hs.config.worker.worker_app is not None:
return
register_servlets_for_client_rest_resource(hs, http_server) register_servlets_for_client_rest_resource(hs, http_server)
BlockRoomRestServlet(hs).register(http_server) BlockRoomRestServlet(hs).register(http_server)
ListRoomRestServlet(hs).register(http_server) ListRoomRestServlet(hs).register(http_server)
@ -254,9 +258,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserTokenRestServlet(hs).register(http_server) UserTokenRestServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server) UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeleteDevicesRestServlet(hs).register(http_server)
UserMediaStatisticsRestServlet(hs).register(http_server) UserMediaStatisticsRestServlet(hs).register(http_server)
EventReportDetailRestServlet(hs).register(http_server) EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server)
@ -280,12 +281,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserByExternalId(hs).register(http_server) UserByExternalId(hs).register(http_server)
UserByThreePid(hs).register(http_server) UserByThreePid(hs).register(http_server)
# Some servlets only get registered for the main process. DeviceRestServlet(hs).register(http_server)
if hs.config.worker.worker_app is None: DevicesRestServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server) DeleteDevicesRestServlet(hs).register(http_server)
BackgroundUpdateEnabledRestServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server)
BackgroundUpdateRestServlet(hs).register(http_server) BackgroundUpdateEnabledRestServlet(hs).register(http_server)
BackgroundUpdateStartJobRestServlet(hs).register(http_server) BackgroundUpdateRestServlet(hs).register(http_server)
BackgroundUpdateStartJobRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource( def register_servlets_for_client_rest_resource(
@ -294,9 +296,11 @@ def register_servlets_for_client_rest_resource(
"""Register only the servlets which need to be exposed on /_matrix/client/xxx""" """Register only the servlets which need to be exposed on /_matrix/client/xxx"""
WhoisRestServlet(hs).register(http_server) WhoisRestServlet(hs).register(http_server)
PurgeHistoryStatusRestServlet(hs).register(http_server) PurgeHistoryStatusRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server)
ResetPasswordRestServlet(hs).register(http_server) # The following resources can only be run on the main process.
if hs.config.worker.worker_app is None:
DeactivateAccountRestServlet(hs).register(http_server)
ResetPasswordRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server)
UserRegisterServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server)
AccountValidityRenewServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server)

View file

@ -16,6 +16,7 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import NotFoundError, SynapseError from synapse.api.errors import NotFoundError, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, assert_params_in_dict,
@ -43,7 +44,9 @@ class DeviceRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
@ -112,7 +115,9 @@ class DevicesRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
@ -143,7 +148,9 @@ class DeleteDevicesRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.is_mine = hs.is_mine self.is_mine = hs.is_mine

View file

@ -20,6 +20,7 @@ from pydantic import Extra, StrictStr
from synapse.api import errors from synapse.api import errors
from synapse.api.errors import NotFoundError from synapse.api.errors import NotFoundError
from synapse.handlers.device import DeviceHandler
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@ -80,7 +81,9 @@ class DeleteDevicesRestServlet(RestServlet):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
class PostBody(RequestBodyModel): class PostBody(RequestBodyModel):
@ -125,7 +128,9 @@ class DeviceRestServlet(RestServlet):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self._msc3852_enabled = hs.config.experimental.msc3852_enabled self._msc3852_enabled = hs.config.experimental.msc3852_enabled
@ -256,7 +261,9 @@ class DehydratedDeviceServlet(RestServlet):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
@ -313,7 +320,9 @@ class ClaimDehydratedDeviceServlet(RestServlet):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_handler = handler
class PostBody(RequestBodyModel): class PostBody(RequestBodyModel):
device_id: StrictStr device_id: StrictStr

View file

@ -15,6 +15,7 @@
import logging import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.handlers.device import DeviceHandler
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@ -34,7 +35,9 @@ class LogoutRestServlet(RestServlet):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler() handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self._device_handler = handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True) requester = await self.auth.get_user_by_req(request, allow_expired=True)
@ -59,7 +62,9 @@ class LogoutAllRestServlet(RestServlet):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler() handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self._device_handler = handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True) requester = await self.auth.get_user_by_req(request, allow_expired=True)

View file

@ -510,7 +510,7 @@ class HomeServer(metaclass=abc.ABCMeta):
) )
@cache_in_self @cache_in_self
def get_device_handler(self): def get_device_handler(self) -> DeviceWorkerHandler:
if self.config.worker.worker_app: if self.config.worker.worker_app:
return DeviceWorkerHandler(self) return DeviceWorkerHandler(self)
else: else:

View file

@ -19,7 +19,7 @@ from typing import Optional
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError, SynapseError from synapse.api.errors import NotFoundError, SynapseError
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
@ -32,7 +32,9 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase): class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None) hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler() handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.handler = handler
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
return hs return hs
@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(res, "fco") self.assertEqual(res, "fco")
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
assert dev is not None
self.assertEqual(dev["display_name"], "display name") self.assertEqual(dev["display_name"], "display name")
def test_device_is_preserved_if_exists(self) -> None: def test_device_is_preserved_if_exists(self) -> None:
@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(res2, "fco") self.assertEqual(res2, "fco")
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
assert dev is not None
self.assertEqual(dev["display_name"], "display name") self.assertEqual(dev["display_name"], "display name")
def test_device_id_is_made_up_if_unspecified(self) -> None: def test_device_id_is_made_up_if_unspecified(self) -> None:
@ -95,6 +99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
) )
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
assert dev is not None
self.assertEqual(dev["display_name"], "display") self.assertEqual(dev["display_name"], "display")
def test_get_devices_by_user(self) -> None: def test_get_devices_by_user(self) -> None:
@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase): class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None) hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler() handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.handler = handler
self.registration = hs.get_registration_handler() self.registration = hs.get_registration_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
) )
) )
retrieved_device_id, device_data = self.get_success( result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
self.handler.get_dehydrated_device(user_id=user_id) assert result is not None
) retrieved_device_id, device_data = result
self.assertEqual(retrieved_device_id, stored_dehydrated_device_id) self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) self.assertEqual(device_data, {"device_data": {"foo": "bar"}})

View file

@ -19,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.handlers.device import DeviceHandler
from synapse.rest.client import login from synapse.rest.client import login
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
@ -34,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
] ]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_device_handler() handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.handler = handler
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")