forked from MirrorHub/synapse
Add a type hint for get_device_handler()
and fix incorrect types. (#14055)
This was the last untyped handler from the HomeServer object. Since it was being treated as Any (and thus unchecked) it was being used incorrectly in a few places.
This commit is contained in:
parent
9b4cb1e2ed
commit
6d47b7e325
16 changed files with 185 additions and 77 deletions
1
changelog.d/14055.misc
Normal file
1
changelog.d/14055.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add missing type hints to `HomeServer`.
|
|
@ -16,6 +16,7 @@ import logging
|
|||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.types import Codes, Requester, UserID, create_requester
|
||||
|
||||
|
@ -76,6 +77,9 @@ class DeactivateAccountHandler:
|
|||
True if identity server supports removing threepids, otherwise False.
|
||||
"""
|
||||
|
||||
# This can only be called on the main process.
|
||||
assert isinstance(self._device_handler, DeviceHandler)
|
||||
|
||||
# Check if this user can be deactivated
|
||||
if not await self._third_party_rules.check_can_deactivate_user(
|
||||
user_id, by_admin
|
||||
|
|
|
@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
|
|||
|
||||
|
||||
class DeviceWorkerHandler:
|
||||
device_list_updater: "DeviceListWorkerUpdater"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.clock = hs.get_clock()
|
||||
self.hs = hs
|
||||
|
@ -76,6 +78,8 @@ class DeviceWorkerHandler:
|
|||
self.server_name = hs.hostname
|
||||
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
|
||||
|
||||
self.device_list_updater = DeviceListWorkerUpdater(hs)
|
||||
|
||||
@trace
|
||||
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
|
||||
"""
|
||||
|
@ -99,6 +103,19 @@ class DeviceWorkerHandler:
|
|||
log_kv(device_map)
|
||||
return devices
|
||||
|
||||
async def get_dehydrated_device(
|
||||
self, user_id: str
|
||||
) -> Optional[Tuple[str, JsonDict]]:
|
||||
"""Retrieve the information for a dehydrated device.
|
||||
|
||||
Args:
|
||||
user_id: the user whose dehydrated device we are looking for
|
||||
Returns:
|
||||
a tuple whose first item is the device ID, and the second item is
|
||||
the dehydrated device information
|
||||
"""
|
||||
return await self.store.get_dehydrated_device(user_id)
|
||||
|
||||
@trace
|
||||
async def get_device(self, user_id: str, device_id: str) -> JsonDict:
|
||||
"""Retrieve the given device
|
||||
|
@ -127,7 +144,7 @@ class DeviceWorkerHandler:
|
|||
@cancellable
|
||||
async def get_device_changes_in_shared_rooms(
|
||||
self, user_id: str, room_ids: Collection[str], from_token: StreamToken
|
||||
) -> Collection[str]:
|
||||
) -> Set[str]:
|
||||
"""Get the set of users whose devices have changed who share a room with
|
||||
the given user.
|
||||
"""
|
||||
|
@ -320,6 +337,8 @@ class DeviceWorkerHandler:
|
|||
|
||||
|
||||
class DeviceHandler(DeviceWorkerHandler):
|
||||
device_list_updater: "DeviceListUpdater"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
|
@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
await self.delete_devices(user_id, [old_device_id])
|
||||
return device_id
|
||||
|
||||
async def get_dehydrated_device(
|
||||
self, user_id: str
|
||||
) -> Optional[Tuple[str, JsonDict]]:
|
||||
"""Retrieve the information for a dehydrated device.
|
||||
|
||||
Args:
|
||||
user_id: the user whose dehydrated device we are looking for
|
||||
Returns:
|
||||
a tuple whose first item is the device ID, and the second item is
|
||||
the dehydrated device information
|
||||
"""
|
||||
return await self.store.get_dehydrated_device(user_id)
|
||||
|
||||
async def rehydrate_device(
|
||||
self, user_id: str, access_token: str, device_id: str
|
||||
) -> dict:
|
||||
|
@ -882,7 +888,36 @@ def _update_device_from_client_ips(
|
|||
)
|
||||
|
||||
|
||||
class DeviceListUpdater:
|
||||
class DeviceListWorkerUpdater:
|
||||
"Handles incoming device list updates from federation and contacts the main process over replication"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
from synapse.replication.http.devices import (
|
||||
ReplicationUserDevicesResyncRestServlet,
|
||||
)
|
||||
|
||||
self._user_device_resync_client = (
|
||||
ReplicationUserDevicesResyncRestServlet.make_client(hs)
|
||||
)
|
||||
|
||||
async def user_device_resync(
|
||||
self, user_id: str, mark_failed_as_stale: bool = True
|
||||
) -> Optional[JsonDict]:
|
||||
"""Fetches all devices for a user and updates the device cache with them.
|
||||
|
||||
Args:
|
||||
user_id: The user's id whose device_list will be updated.
|
||||
mark_failed_as_stale: Whether to mark the user's device list as stale
|
||||
if the attempt to resync failed.
|
||||
Returns:
|
||||
A dict with device info as under the "devices" in the result of this
|
||||
request:
|
||||
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
|
||||
"""
|
||||
return await self._user_device_resync_client(user_id=user_id)
|
||||
|
||||
|
||||
class DeviceListUpdater(DeviceListWorkerUpdater):
|
||||
"Handles incoming device list updates from federation and updates the DB"
|
||||
|
||||
def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
|
||||
|
|
|
@ -27,9 +27,9 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.constants import EduTypes
|
||||
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
|
||||
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
UserID,
|
||||
|
@ -56,27 +56,23 @@ class E2eKeysHandler:
|
|||
self.is_mine = hs.is_mine
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
self._edu_updater = SigningKeyEduUpdater(hs, self)
|
||||
|
||||
federation_registry = hs.get_federation_registry()
|
||||
|
||||
self._is_master = hs.config.worker.worker_app is None
|
||||
if not self._is_master:
|
||||
self._user_device_resync_client = (
|
||||
ReplicationUserDevicesResyncRestServlet.make_client(hs)
|
||||
)
|
||||
else:
|
||||
is_master = hs.config.worker.worker_app is None
|
||||
if is_master:
|
||||
edu_updater = SigningKeyEduUpdater(hs)
|
||||
|
||||
# Only register this edu handler on master as it requires writing
|
||||
# device updates to the db
|
||||
federation_registry.register_edu_handler(
|
||||
EduTypes.SIGNING_KEY_UPDATE,
|
||||
self._edu_updater.incoming_signing_key_update,
|
||||
edu_updater.incoming_signing_key_update,
|
||||
)
|
||||
# also handle the unstable version
|
||||
# FIXME: remove this when enough servers have upgraded
|
||||
federation_registry.register_edu_handler(
|
||||
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
|
||||
self._edu_updater.incoming_signing_key_update,
|
||||
edu_updater.incoming_signing_key_update,
|
||||
)
|
||||
|
||||
# doesn't really work as part of the generic query API, because the
|
||||
|
@ -319,14 +315,13 @@ class E2eKeysHandler:
|
|||
# probably be tracking their device lists. However, we haven't
|
||||
# done an initial sync on the device list so we do it now.
|
||||
try:
|
||||
if self._is_master:
|
||||
resync_results = await self.device_handler.device_list_updater.user_device_resync(
|
||||
resync_results = (
|
||||
await self.device_handler.device_list_updater.user_device_resync(
|
||||
user_id
|
||||
)
|
||||
else:
|
||||
resync_results = await self._user_device_resync_client(
|
||||
user_id=user_id
|
||||
)
|
||||
)
|
||||
if resync_results is None:
|
||||
raise ValueError("Device resync failed")
|
||||
|
||||
# Add the device keys to the results.
|
||||
user_devices = resync_results["devices"]
|
||||
|
@ -605,6 +600,8 @@ class E2eKeysHandler:
|
|||
async def upload_keys_for_user(
|
||||
self, user_id: str, device_id: str, keys: JsonDict
|
||||
) -> JsonDict:
|
||||
# This can only be called from the main process.
|
||||
assert isinstance(self.device_handler, DeviceHandler)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
|
@ -732,6 +729,8 @@ class E2eKeysHandler:
|
|||
user_id: the user uploading the keys
|
||||
keys: the signing keys
|
||||
"""
|
||||
# This can only be called from the main process.
|
||||
assert isinstance(self.device_handler, DeviceHandler)
|
||||
|
||||
# if a master key is uploaded, then check it. Otherwise, load the
|
||||
# stored master key, to check signatures on other keys
|
||||
|
@ -823,6 +822,9 @@ class E2eKeysHandler:
|
|||
Raises:
|
||||
SynapseError: if the signatures dict is not valid.
|
||||
"""
|
||||
# This can only be called from the main process.
|
||||
assert isinstance(self.device_handler, DeviceHandler)
|
||||
|
||||
failures = {}
|
||||
|
||||
# signatures to be stored. Each item will be a SignatureListItem
|
||||
|
@ -1200,6 +1202,9 @@ class E2eKeysHandler:
|
|||
A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
|
||||
If the key cannot be retrieved, all values in the tuple will instead be None.
|
||||
"""
|
||||
# This can only be called from the main process.
|
||||
assert isinstance(self.device_handler, DeviceHandler)
|
||||
|
||||
try:
|
||||
remote_result = await self.federation.query_user_devices(
|
||||
user.domain, user.to_string()
|
||||
|
@ -1396,11 +1401,14 @@ class SignatureListItem:
|
|||
class SigningKeyEduUpdater:
|
||||
"""Handles incoming signing key updates from federation and updates the DB"""
|
||||
|
||||
def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastores().main
|
||||
self.federation = hs.get_federation_client()
|
||||
self.clock = hs.get_clock()
|
||||
self.e2e_keys_handler = e2e_keys_handler
|
||||
|
||||
device_handler = hs.get_device_handler()
|
||||
assert isinstance(device_handler, DeviceHandler)
|
||||
self._device_handler = device_handler
|
||||
|
||||
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
|
||||
|
||||
|
@ -1445,9 +1453,6 @@ class SigningKeyEduUpdater:
|
|||
user_id: the user whose updates we are processing
|
||||
"""
|
||||
|
||||
device_handler = self.e2e_keys_handler.device_handler
|
||||
device_list_updater = device_handler.device_list_updater
|
||||
|
||||
async with self._remote_edu_linearizer.queue(user_id):
|
||||
pending_updates = self._pending_updates.pop(user_id, [])
|
||||
if not pending_updates:
|
||||
|
@ -1459,13 +1464,11 @@ class SigningKeyEduUpdater:
|
|||
logger.info("pending updates: %r", pending_updates)
|
||||
|
||||
for master_key, self_signing_key in pending_updates:
|
||||
new_device_ids = (
|
||||
await device_list_updater.process_cross_signing_key_update(
|
||||
user_id,
|
||||
master_key,
|
||||
self_signing_key,
|
||||
)
|
||||
new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update(
|
||||
user_id,
|
||||
master_key,
|
||||
self_signing_key,
|
||||
)
|
||||
device_ids = device_ids + new_device_ids
|
||||
|
||||
await device_handler.notify_device_update(user_id, device_ids)
|
||||
await self._device_handler.notify_device_update(user_id, device_ids)
|
||||
|
|
|
@ -38,6 +38,7 @@ from synapse.api.errors import (
|
|||
)
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.config.server import is_threepid_reserved
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.http.servlet import assert_params_in_dict
|
||||
from synapse.replication.http.login import RegisterDeviceReplicationServlet
|
||||
from synapse.replication.http.register import (
|
||||
|
@ -841,6 +842,9 @@ class RegistrationHandler:
|
|||
refresh_token = None
|
||||
refresh_token_id = None
|
||||
|
||||
# This can only run on the main process.
|
||||
assert isinstance(self.device_handler, DeviceHandler)
|
||||
|
||||
registered_device_id = await self.device_handler.check_device_registered(
|
||||
user_id,
|
||||
device_id,
|
||||
|
|
|
@ -15,6 +15,7 @@ import logging
|
|||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.types import Requester
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -29,7 +30,10 @@ class SetPasswordHandler:
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastores().main
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
# This can only be instantiated on the main process.
|
||||
device_handler = hs.get_device_handler()
|
||||
assert isinstance(device_handler, DeviceHandler)
|
||||
self._device_handler = device_handler
|
||||
|
||||
async def set_password(
|
||||
self,
|
||||
|
|
|
@ -37,6 +37,7 @@ from twisted.web.server import Request
|
|||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
|
||||
from synapse.config.sso import SsoAttributeRequirement
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.handlers.register import init_counters_for_auth_provider
|
||||
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
|
||||
from synapse.http import get_request_user_agent
|
||||
|
@ -1035,6 +1036,8 @@ class SsoHandler:
|
|||
) -> None:
|
||||
"""Revoke any devices and in-flight logins tied to a provider session.
|
||||
|
||||
Can only be called from the main process.
|
||||
|
||||
Args:
|
||||
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
||||
"oidc" or "saml".
|
||||
|
@ -1042,6 +1045,12 @@ class SsoHandler:
|
|||
expected_user_id: The user we're expecting to logout. If set, it will ignore
|
||||
sessions belonging to other users and log an error.
|
||||
"""
|
||||
|
||||
# It is expected that this is the main process.
|
||||
assert isinstance(
|
||||
self._device_handler, DeviceHandler
|
||||
), "revoking SSO sessions can only be called on the main process"
|
||||
|
||||
# Invalidate any running user-mapping sessions
|
||||
to_delete = []
|
||||
for session_id, session in self._username_mapping_sessions.items():
|
||||
|
|
|
@ -86,6 +86,7 @@ from synapse.handlers.auth import (
|
|||
ON_LOGGED_OUT_CALLBACK,
|
||||
AuthHandler,
|
||||
)
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.handlers.push_rules import RuleSpec, check_actions
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.http.server import (
|
||||
|
@ -207,6 +208,7 @@ class ModuleApi:
|
|||
self._registration_handler = hs.get_registration_handler()
|
||||
self._send_email_handler = hs.get_send_email_handler()
|
||||
self._push_rules_handler = hs.get_push_rules_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
self.custom_template_dir = hs.config.server.custom_template_directory
|
||||
|
||||
try:
|
||||
|
@ -784,6 +786,8 @@ class ModuleApi:
|
|||
) -> Generator["defer.Deferred[Any]", Any, None]:
|
||||
"""Invalidate an access token for a user
|
||||
|
||||
Can only be called from the main process.
|
||||
|
||||
Added in Synapse v0.25.0.
|
||||
|
||||
Args:
|
||||
|
@ -796,6 +800,10 @@ class ModuleApi:
|
|||
Raises:
|
||||
synapse.api.errors.AuthError: the access token is invalid
|
||||
"""
|
||||
assert isinstance(
|
||||
self._device_handler, DeviceHandler
|
||||
), "invalidate_access_token can only be called on the main process"
|
||||
|
||||
# see if the access token corresponds to a device
|
||||
user_info = yield defer.ensureDeferred(
|
||||
self._auth.get_user_by_access_token(access_token)
|
||||
|
@ -805,7 +813,7 @@ class ModuleApi:
|
|||
if device_id:
|
||||
# delete the device, which will also delete its access tokens
|
||||
yield defer.ensureDeferred(
|
||||
self._hs.get_device_handler().delete_devices(user_id, [device_id])
|
||||
self._device_handler.delete_devices(user_id, [device_id])
|
||||
)
|
||||
else:
|
||||
# no associated device. Just delete the access token.
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
|
@ -63,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
self.device_list_updater = hs.get_device_handler().device_list_updater
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
|
||||
handler = hs.get_device_handler()
|
||||
assert isinstance(handler, DeviceHandler)
|
||||
self.device_list_updater = handler.device_list_updater
|
||||
|
||||
self.store = hs.get_datastores().main
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
|
@ -73,7 +78,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
|
|||
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
) -> Tuple[int, Optional[JsonDict]]:
|
||||
user_devices = await self.device_list_updater.user_device_resync(user_id)
|
||||
|
||||
return 200, user_devices
|
||||
|
|
|
@ -238,6 +238,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
|||
"""
|
||||
Register all the admin servlets.
|
||||
"""
|
||||
# Admin servlets aren't registered on workers.
|
||||
if hs.config.worker.worker_app is not None:
|
||||
return
|
||||
|
||||
register_servlets_for_client_rest_resource(hs, http_server)
|
||||
BlockRoomRestServlet(hs).register(http_server)
|
||||
ListRoomRestServlet(hs).register(http_server)
|
||||
|
@ -254,9 +258,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
|||
UserTokenRestServlet(hs).register(http_server)
|
||||
UserRestServletV2(hs).register(http_server)
|
||||
UsersRestServletV2(hs).register(http_server)
|
||||
DeviceRestServlet(hs).register(http_server)
|
||||
DevicesRestServlet(hs).register(http_server)
|
||||
DeleteDevicesRestServlet(hs).register(http_server)
|
||||
UserMediaStatisticsRestServlet(hs).register(http_server)
|
||||
EventReportDetailRestServlet(hs).register(http_server)
|
||||
EventReportsRestServlet(hs).register(http_server)
|
||||
|
@ -280,12 +281,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
|||
UserByExternalId(hs).register(http_server)
|
||||
UserByThreePid(hs).register(http_server)
|
||||
|
||||
# Some servlets only get registered for the main process.
|
||||
if hs.config.worker.worker_app is None:
|
||||
SendServerNoticeServlet(hs).register(http_server)
|
||||
BackgroundUpdateEnabledRestServlet(hs).register(http_server)
|
||||
BackgroundUpdateRestServlet(hs).register(http_server)
|
||||
BackgroundUpdateStartJobRestServlet(hs).register(http_server)
|
||||
DeviceRestServlet(hs).register(http_server)
|
||||
DevicesRestServlet(hs).register(http_server)
|
||||
DeleteDevicesRestServlet(hs).register(http_server)
|
||||
SendServerNoticeServlet(hs).register(http_server)
|
||||
BackgroundUpdateEnabledRestServlet(hs).register(http_server)
|
||||
BackgroundUpdateRestServlet(hs).register(http_server)
|
||||
BackgroundUpdateStartJobRestServlet(hs).register(http_server)
|
||||
|
||||
|
||||
def register_servlets_for_client_rest_resource(
|
||||
|
@ -294,9 +296,11 @@ def register_servlets_for_client_rest_resource(
|
|||
"""Register only the servlets which need to be exposed on /_matrix/client/xxx"""
|
||||
WhoisRestServlet(hs).register(http_server)
|
||||
PurgeHistoryStatusRestServlet(hs).register(http_server)
|
||||
DeactivateAccountRestServlet(hs).register(http_server)
|
||||
PurgeHistoryRestServlet(hs).register(http_server)
|
||||
ResetPasswordRestServlet(hs).register(http_server)
|
||||
# The following resources can only be run on the main process.
|
||||
if hs.config.worker.worker_app is None:
|
||||
DeactivateAccountRestServlet(hs).register(http_server)
|
||||
ResetPasswordRestServlet(hs).register(http_server)
|
||||
SearchUsersRestServlet(hs).register(http_server)
|
||||
UserRegisterServlet(hs).register(http_server)
|
||||
AccountValidityRenewServlet(hs).register(http_server)
|
||||
|
|
|
@ -16,6 +16,7 @@ from http import HTTPStatus
|
|||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import NotFoundError, SynapseError
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
|
@ -43,7 +44,9 @@ class DeviceRestServlet(RestServlet):
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
handler = hs.get_device_handler()
|
||||
assert isinstance(handler, DeviceHandler)
|
||||
self.device_handler = handler
|
||||
self.store = hs.get_datastores().main
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
|
@ -112,7 +115,9 @@ class DevicesRestServlet(RestServlet):
|
|||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
handler = hs.get_device_handler()
|
||||
assert isinstance(handler, DeviceHandler)
|
||||
self.device_handler = handler
|
||||
self.store = hs.get_datastores().main
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
|
@ -143,7 +148,9 @@ class DeleteDevicesRestServlet(RestServlet):
|
|||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
handler = hs.get_device_handler()
|
||||
assert isinstance(handler, DeviceHandler)
|
||||
self.device_handler = handler
|
||||
self.store = hs.get_datastores().main
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from pydantic import Extra, StrictStr
|
|||
|
||||
from synapse.api import errors
|
||||
from synapse.api.errors import NotFoundError
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
|
@ -80,7 +81,9 @@ class DeleteDevicesRestServlet(RestServlet):
|
|||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
handler = hs.get_device_handler()
|
||||
assert isinstance(handler, DeviceHandler)
|
||||
self.device_handler = handler
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
|
@ -125,7 +128,9 @@ class DeviceRestServlet(RestServlet):
|
|||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
handler = hs.get_device_handler()
|
||||
assert isinstance(handler, DeviceHandler)
|
||||
self.device_handler = handler
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
|
||||
|
||||
|
@ -256,7 +261,9 @@ class DehydratedDeviceServlet(RestServlet):
|
|||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
handler = hs.get_device_handler()
|
||||
assert isinstance(handler, DeviceHandler)
|
||||
self.device_handler = handler
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
@ -313,7 +320,9 @@ class ClaimDehydratedDeviceServlet(RestServlet):
|
|||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
handler = hs.get_device_handler()
|
||||
assert isinstance(handler, DeviceHandler)
|
||||
self.device_handler = handler
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
device_id: StrictStr
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.http.site import SynapseRequest
|
||||
|
@ -34,7 +35,9 @@ class LogoutRestServlet(RestServlet):
|
|||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
handler = hs.get_device_handler()
|
||||
assert isinstance(handler, DeviceHandler)
|
||||
self._device_handler = handler
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_expired=True)
|
||||
|
@ -59,7 +62,9 @@ class LogoutAllRestServlet(RestServlet):
|
|||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
handler = hs.get_device_handler()
|
||||
assert isinstance(handler, DeviceHandler)
|
||||
self._device_handler = handler
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_expired=True)
|
||||
|
|
|
@ -510,7 +510,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
)
|
||||
|
||||
@cache_in_self
|
||||
def get_device_handler(self):
|
||||
def get_device_handler(self) -> DeviceWorkerHandler:
|
||||
if self.config.worker.worker_app:
|
||||
return DeviceWorkerHandler(self)
|
||||
else:
|
||||
|
|
|
@ -19,7 +19,7 @@ from typing import Optional
|
|||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import NotFoundError, SynapseError
|
||||
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
|
||||
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
|
@ -32,7 +32,9 @@ user2 = "@theresa:bbb"
|
|||
class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
hs = self.setup_test_homeserver("server", federation_http_client=None)
|
||||
self.handler = hs.get_device_handler()
|
||||
handler = hs.get_device_handler()
|
||||
assert isinstance(handler, DeviceHandler)
|
||||
self.handler = handler
|
||||
self.store = hs.get_datastores().main
|
||||
return hs
|
||||
|
||||
|
@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(res, "fco")
|
||||
|
||||
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
|
||||
assert dev is not None
|
||||
self.assertEqual(dev["display_name"], "display name")
|
||||
|
||||
def test_device_is_preserved_if_exists(self) -> None:
|
||||
|
@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(res2, "fco")
|
||||
|
||||
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
|
||||
assert dev is not None
|
||||
self.assertEqual(dev["display_name"], "display name")
|
||||
|
||||
def test_device_id_is_made_up_if_unspecified(self) -> None:
|
||||
|
@ -95,6 +99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
|
||||
assert dev is not None
|
||||
self.assertEqual(dev["display_name"], "display")
|
||||
|
||||
def test_get_devices_by_user(self) -> None:
|
||||
|
@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
|||
class DehydrationTestCase(unittest.HomeserverTestCase):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
hs = self.setup_test_homeserver("server", federation_http_client=None)
|
||||
self.handler = hs.get_device_handler()
|
||||
handler = hs.get_device_handler()
|
||||
assert isinstance(handler, DeviceHandler)
|
||||
self.handler = handler
|
||||
self.registration = hs.get_registration_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
|
@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
|
||||
retrieved_device_id, device_data = self.get_success(
|
||||
self.handler.get_dehydrated_device(user_id=user_id)
|
||||
)
|
||||
result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
|
||||
assert result is not None
|
||||
retrieved_device_id, device_data = result
|
||||
|
||||
self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
|
||||
self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
|
||||
|
|
|
@ -19,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor
|
|||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.rest.client import login
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
@ -34,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
|
|||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.handler = hs.get_device_handler()
|
||||
handler = hs.get_device_handler()
|
||||
assert isinstance(handler, DeviceHandler)
|
||||
self.handler = handler
|
||||
|
||||
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||
self.admin_user_tok = self.login("admin", "pass")
|
||||
|
|
Loading…
Reference in a new issue