0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-09 22:28:55 +02:00

Batch up replication requests to request the resyncing of remote users's devices. (#14716)

This commit is contained in:
reivilibre 2023-01-10 11:17:59 +00:00 committed by GitHub
parent 3479599387
commit ba4ea7d13f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 306 additions and 79 deletions

1
changelog.d/14716.misc Normal file
View file

@ -0,0 +1 @@
Batch up replication requests to request the resyncing of remote users's devices.

View file

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
@ -33,6 +34,7 @@ from synapse.api.errors import (
Codes,
FederationDeniedError,
HttpResponseException,
InvalidAPICallError,
RequestSendFailed,
SynapseError,
)
@ -45,6 +47,7 @@ from synapse.types import (
JsonDict,
StreamKeyType,
StreamToken,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
@ -893,12 +896,47 @@ class DeviceListWorkerUpdater:
def __init__(self, hs: "HomeServer"):
from synapse.replication.http.devices import (
ReplicationMultiUserDevicesResyncRestServlet,
ReplicationUserDevicesResyncRestServlet,
)
self._user_device_resync_client = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
self._multi_user_device_resync_client = (
ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
)
async def multi_user_device_resync(
self, user_ids: List[str], mark_failed_as_stale: bool = True
) -> Dict[str, Optional[JsonDict]]:
"""
Like `user_device_resync` but operates on multiple users **from the same origin**
at once.
Returns:
Dict from User ID to the same Dict as `user_device_resync`.
"""
# mark_failed_as_stale is not sent. Ensure this doesn't break expectations.
assert mark_failed_as_stale
if not user_ids:
# Shortcut empty requests
return {}
try:
return await self._multi_user_device_resync_client(user_ids=user_ids)
except SynapseError as err:
if not (
err.code == HTTPStatus.NOT_FOUND and err.errcode == Codes.UNRECOGNIZED
):
raise
# Fall back to single requests
result: Dict[str, Optional[JsonDict]] = {}
for user_id in user_ids:
result[user_id] = await self._user_device_resync_client(user_id=user_id)
return result
async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
@ -913,8 +951,10 @@ class DeviceListWorkerUpdater:
A dict with device info as under the "devices" in the result of this
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
None when we weren't able to fetch the device info for some reason,
e.g. due to a connection problem.
"""
return await self._user_device_resync_client(user_id=user_id)
return (await self.multi_user_device_resync([user_id]))[user_id]
class DeviceListUpdater(DeviceListWorkerUpdater):
@ -1160,19 +1200,66 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
# Allow future calls to retry resyncinc out of sync device lists.
self._resync_retry_in_progress = False
async def multi_user_device_resync(
self, user_ids: List[str], mark_failed_as_stale: bool = True
) -> Dict[str, Optional[JsonDict]]:
"""
Like `user_device_resync` but operates on multiple users **from the same origin**
at once.
Returns:
Dict from User ID to the same Dict as `user_device_resync`.
"""
if not user_ids:
return {}
origins = {UserID.from_string(user_id).domain for user_id in user_ids}
if len(origins) != 1:
raise InvalidAPICallError(f"Only one origin permitted, got {origins!r}")
result = {}
failed = set()
# TODO(Perf): Actually batch these up
for user_id in user_ids:
user_result, user_failed = await self._user_device_resync_returning_failed(
user_id
)
result[user_id] = user_result
if user_failed:
failed.add(user_id)
if mark_failed_as_stale:
await self.store.mark_remote_users_device_caches_as_stale(failed)
return result
async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
) -> Optional[JsonDict]:
result, failed = await self._user_device_resync_returning_failed(user_id)
if failed and mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
await self.store.mark_remote_users_device_caches_as_stale((user_id,))
return result
async def _user_device_resync_returning_failed(
self, user_id: str
) -> Tuple[Optional[JsonDict], bool]:
"""Fetches all devices for a user and updates the device cache with them.
Args:
user_id: The user's id whose device_list will be updated.
mark_failed_as_stale: Whether to mark the user's device list as stale
if the attempt to resync failed.
Returns:
A dict with device info as under the "devices" in the result of this
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
- A dict with device info as under the "devices" in the result of this
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
None when we weren't able to fetch the device info for some reason,
e.g. due to a connection problem.
- True iff the resync failed and the device list should be marked as stale.
"""
logger.debug("Attempting to resync the device list for %s", user_id)
log_kv({"message": "Doing resync to update device list."})
@ -1181,12 +1268,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
try:
result = await self.federation.query_user_devices(origin, user_id)
except NotRetryingDestination:
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
await self.store.mark_remote_user_device_cache_as_stale(user_id)
return None
return None, True
except (RequestSendFailed, HttpResponseException) as e:
logger.warning(
"Failed to handle device list update for %s: %s",
@ -1194,23 +1276,18 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
e,
)
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
await self.store.mark_remote_user_device_cache_as_stale(user_id)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
# is out of date. If we bail then we will retry the resync
# next time we get a device list update for this user_id.
# This makes it more likely that the device lists will
# eventually become consistent.
return None
return None, True
except FederationDeniedError as e:
set_tag("error", True)
log_kv({"reason": "FederationDeniedError"})
logger.info(e)
return None
return None, False
except Exception as e:
set_tag("error", True)
log_kv(
@ -1218,12 +1295,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
)
logger.exception("Failed to handle device list update for %s", user_id)
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
await self.store.mark_remote_user_device_cache_as_stale(user_id)
return None
return None, True
log_kv({"result": result})
stream_id = result["stream_id"]
devices = result["devices"]
@ -1305,7 +1377,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
# point.
self._seen_updates[user_id] = {stream_id}
return result
return result, False
async def process_cross_signing_key_update(
self,

View file

@ -195,7 +195,7 @@ class DeviceMessageHandler:
sender_user_id,
unknown_devices,
)
await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,))
# Immediately attempt a resync in the background
run_in_background(self._user_device_resync, user_id=sender_user_id)

View file

@ -36,8 +36,8 @@ from synapse.types import (
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, delay_cancellation
from synapse.util import json_decoder
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.cancellation import cancellable
from synapse.util.retryutils import NotRetryingDestination
@ -238,24 +238,28 @@ class E2eKeysHandler:
# Now fetch any devices that we don't have in our cache
# TODO It might make sense to propagate cancellations into the
# deferreds which are querying remote homeservers.
await make_deferred_yieldable(
delay_cancellation(
defer.gatherResults(
[
run_in_background(
self._query_devices_for_destination,
results,
cross_signing_keys,
failures,
destination,
queries,
timeout,
)
for destination, queries in remote_queries_not_in_cache.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
logger.debug(
"%d destinations to query devices for", len(remote_queries_not_in_cache)
)
async def _query(
destination_queries: Tuple[str, Dict[str, Iterable[str]]]
) -> None:
destination, queries = destination_queries
return await self._query_devices_for_destination(
results,
cross_signing_keys,
failures,
destination,
queries,
timeout,
)
await concurrently_execute(
_query,
remote_queries_not_in_cache.items(),
10,
delay_cancellation=True,
)
ret = {"device_keys": results, "failures": failures}
@ -300,28 +304,41 @@ class E2eKeysHandler:
# queries. We use the more efficient batched query_client_keys for all
# remaining users
user_ids_updated = []
for (user_id, device_list) in destination_query.items():
if user_id in user_ids_updated:
continue
if device_list:
continue
# Perform a user device resync for each user only once and only as long as:
# - they have an empty device_list
# - they are in some rooms that this server can see
users_to_resync_devices = {
user_id
for (user_id, device_list) in destination_query.items()
if (not device_list) and (await self.store.get_rooms_for_user(user_id))
}
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
continue
logger.debug(
"%d users to resync devices for from destination %s",
len(users_to_resync_devices),
destination,
)
# We've decided we're sharing a room with this user and should
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
resync_results = (
await self.device_handler.device_list_updater.user_device_resync(
user_id
)
try:
user_resync_results = (
await self.device_handler.device_list_updater.multi_user_device_resync(
list(users_to_resync_devices)
)
)
for user_id in users_to_resync_devices:
resync_results = user_resync_results[user_id]
if resync_results is None:
raise ValueError("Device resync failed")
# TODO: It's weird that we'll store a failure against a
# destination, yet continue processing users from that
# destination.
# We might want to consider changing this, but for now
# I'm leaving it as I found it.
failures[destination] = _exception_to_failure(
ValueError(f"Device resync failed for {user_id!r}")
)
continue
# Add the device keys to the results.
user_devices = resync_results["devices"]
@ -339,8 +356,8 @@ class E2eKeysHandler:
if self_signing_key:
cross_signing_keys["self_signing_keys"][user_id] = self_signing_key
except Exception as e:
failures[destination] = _exception_to_failure(e)
except Exception as e:
failures[destination] = _exception_to_failure(e)
if len(destination_query) == len(user_ids_updated):
# We've updated all the users in the query and we do not need to

View file

@ -1423,7 +1423,7 @@ class FederationEventHandler:
"""
try:
await self._store.mark_remote_user_device_cache_as_stale(sender)
await self._store.mark_remote_users_device_caches_as_stale((sender,))
# Immediately attempt a resync in the background
if self._config.worker.worker_app:

View file

@ -13,12 +13,13 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.logging.opentracing import active_span
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
@ -84,6 +85,76 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
return 200, user_devices
class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
"""Ask master to resync the device list for multiple users from the same
remote server by contacting their server.
This must happen on master so that the results can be correctly cached in
the database and streamed to workers.
Request format:
POST /_synapse/replication/multi_user_device_resync
{
"user_ids": ["@alice:example.org", "@bob:example.org", ...]
}
Response is roughly equivalent to ` /_matrix/federation/v1/user/devices/:user_id`
response, but there is a map from user ID to response, e.g.:
{
"@alice:example.org": {
"devices": [
{
"device_id": "JLAFKJWSCS",
"keys": { ... },
"device_display_name": "Alice's Mobile Phone"
}
]
},
...
}
"""
NAME = "multi_user_device_resync"
PATH_ARGS = ()
CACHE = False
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
from synapse.handlers.device import DeviceHandler
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.device_list_updater = handler.device_list_updater
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_ids: List[str]) -> JsonDict: # type: ignore[override]
return {"user_ids": user_ids}
async def _handle_request( # type: ignore[override]
self, request: Request
) -> Tuple[int, Dict[str, Optional[JsonDict]]]:
content = parse_json_object_from_request(request)
user_ids: List[str] = content["user_ids"]
logger.info("Resync for %r", user_ids)
span = active_span()
if span:
span.set_tag("user_ids", f"{user_ids!r}")
multi_user_devices = await self.device_list_updater.multi_user_device_resync(
user_ids
)
return 200, multi_user_devices
class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
"""Ask master to upload keys for the user and send them out over federation to
update other servers.
@ -151,4 +222,5 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
ReplicationMultiUserDevicesResyncRestServlet(hs).register(http_server)
ReplicationUploadKeysForUserRestServlet(hs).register(http_server)

View file

@ -54,7 +54,7 @@ from synapse.storage.util.id_generators import (
AbstractStreamIdTracker,
StreamIdGenerator,
)
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
@ -1069,16 +1069,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return {row["user_id"] for row in rows}
async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
async def mark_remote_users_device_caches_as_stale(
self, user_ids: StrCollection
) -> None:
"""Records that the server has reason to believe the cache of the devices
for the remote users is out of date.
"""
await self.db_pool.simple_upsert(
table="device_lists_remote_resync",
keyvalues={"user_id": user_id},
values={},
insertion_values={"added_ts": self._clock.time_msec()},
desc="mark_remote_user_device_cache_as_stale",
def _mark_remote_users_device_caches_as_stale_txn(
txn: LoggingTransaction,
) -> None:
# TODO add insertion_values support to simple_upsert_many and use
# that!
for user_id in user_ids:
self.db_pool.simple_upsert_txn(
txn,
table="device_lists_remote_resync",
keyvalues={"user_id": user_id},
values={},
insertion_values={"added_ts": self._clock.time_msec()},
)
await self.db_pool.runInteraction(
"mark_remote_users_device_caches_as_stale",
_mark_remote_users_device_caches_as_stale_txn,
)
async def mark_remote_user_device_cache_as_valid(self, user_id: str) -> None:

View file

@ -77,6 +77,10 @@ JsonMapping = Mapping[str, Any]
# A JSON-serialisable object.
JsonSerializable = object
# Collection[str] that does not include str itself; str being a Sequence[str]
# is very misleading and results in bugs.
StrCollection = Union[Tuple[str, ...], List[str], Set[str]]
# Note that this seems to require inheriting *directly* from Interface in order
# for mypy-zope to realize it is an interface.

View file

@ -205,7 +205,10 @@ T = TypeVar("T")
async def concurrently_execute(
func: Callable[[T], Any], args: Iterable[T], limit: int
func: Callable[[T], Any],
args: Iterable[T],
limit: int,
delay_cancellation: bool = False,
) -> None:
"""Executes the function with each argument concurrently while limiting
the number of concurrent executions.
@ -215,6 +218,8 @@ async def concurrently_execute(
args: List of arguments to pass to func, each invocation of func
gets a single argument.
limit: Maximum number of conccurent executions.
delay_cancellation: Whether to delay cancellation until after the invocations
have finished.
Returns:
None, when all function invocations have finished. The return values
@ -233,9 +238,16 @@ async def concurrently_execute(
# We use `itertools.islice` to handle the case where the number of args is
# less than the limit, avoiding needlessly spawning unnecessary background
# tasks.
await yieldable_gather_results(
_concurrently_execute_inner, (value for value in itertools.islice(it, limit))
)
if delay_cancellation:
await yieldable_gather_results_delaying_cancellation(
_concurrently_execute_inner,
(value for value in itertools.islice(it, limit)),
)
else:
await yieldable_gather_results(
_concurrently_execute_inner,
(value for value in itertools.islice(it, limit)),
)
P = ParamSpec("P")
@ -292,6 +304,41 @@ async def yieldable_gather_results(
raise dfe.subFailure.value from None
async def yieldable_gather_results_delaying_cancellation(
func: Callable[Concatenate[T, P], Awaitable[R]],
iter: Iterable[T],
*args: P.args,
**kwargs: P.kwargs,
) -> List[R]:
"""Executes the function with each argument concurrently.
Cancellation is delayed until after all the results have been gathered.
See `yieldable_gather_results`.
Args:
func: Function to execute that returns a Deferred
iter: An iterable that yields items that get passed as the first
argument to the function
*args: Arguments to be passed to each call to func
**kwargs: Keyword arguments to be passed to each call to func
Returns
A list containing the results of the function
"""
try:
return await make_deferred_yieldable(
delay_cancellation(
defer.gatherResults(
[run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type]
consumeErrors=True,
)
)
)
except defer.FirstError as dfe:
assert isinstance(dfe.subFailure.value, BaseException)
raise dfe.subFailure.value from None
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")