0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-15 22:42:23 +01:00

Wait for streams to catch up when processing HTTP replication. (#14820)

This should hopefully mitigate a class of races where data gets out of
sync due a HTTP replication request racing with the replication streams.
This commit is contained in:
Erik Johnston 2023-01-18 19:35:29 +00:00 committed by GitHub
parent e8f2bf5c40
commit 9187fd940e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 225 additions and 144 deletions

1
changelog.d/14820.bugfix Normal file
View file

@ -0,0 +1 @@
Fix rare races when using workers.

View file

@ -2259,6 +2259,10 @@ class FederationEventHandler:
event_and_contexts, backfilled=backfilled event_and_contexts, backfilled=backfilled
) )
# After persistence we always need to notify replication there may
# be new data.
self._notifier.notify_replication()
if self._ephemeral_messages_enabled: if self._ephemeral_messages_enabled:
for event in events: for event in events:
# If there's an expiry timestamp on the event, schedule its expiry. # If there's an expiry timestamp on the event, schedule its expiry.

View file

@ -17,7 +17,7 @@ import logging
import re import re
import urllib.parse import urllib.parse
from inspect import signature from inspect import signature
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple from typing import TYPE_CHECKING, Any, Awaitable, Callable, ClassVar, Dict, List, Tuple
from prometheus_client import Counter, Gauge from prometheus_client import Counter, Gauge
@ -27,6 +27,7 @@ from twisted.web.server import Request
from synapse.api.errors import HttpResponseException, SynapseError from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError from synapse.http import RequestTimedOutError
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging import opentracing from synapse.logging import opentracing
from synapse.logging.opentracing import trace_with_opname from synapse.logging.opentracing import trace_with_opname
@ -53,6 +54,9 @@ _outgoing_request_counter = Counter(
) )
_STREAM_POSITION_KEY = "_INT_STREAM_POS"
class ReplicationEndpoint(metaclass=abc.ABCMeta): class ReplicationEndpoint(metaclass=abc.ABCMeta):
"""Helper base class for defining new replication HTTP endpoints. """Helper base class for defining new replication HTTP endpoints.
@ -94,6 +98,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
a connection error is received. a connection error is received.
RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when
receiving connection errors, each will backoff exponentially longer. receiving connection errors, each will backoff exponentially longer.
WAIT_FOR_STREAMS (bool): Whether to wait for replication streams to
catch up before processing the request and/or response. Defaults to
True.
""" """
NAME: str = abc.abstractproperty() # type: ignore NAME: str = abc.abstractproperty() # type: ignore
@ -104,6 +111,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
RETRY_ON_CONNECT_ERROR = True RETRY_ON_CONNECT_ERROR = True
RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5 # =63s (2^6-1) RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5 # =63s (2^6-1)
WAIT_FOR_STREAMS: ClassVar[bool] = True
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
if self.CACHE: if self.CACHE:
self.response_cache: ResponseCache[str] = ResponseCache( self.response_cache: ResponseCache[str] = ResponseCache(
@ -126,6 +135,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if hs.config.worker.worker_replication_secret: if hs.config.worker.worker_replication_secret:
self._replication_secret = hs.config.worker.worker_replication_secret self._replication_secret = hs.config.worker.worker_replication_secret
self._streams = hs.get_replication_command_handler().get_streams_to_replicate()
self._replication = hs.get_replication_data_handler()
self._instance_name = hs.get_instance_name()
def _check_auth(self, request: Request) -> None: def _check_auth(self, request: Request) -> None:
# Get the authorization header. # Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
@ -160,7 +173,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def _handle_request( async def _handle_request(
self, request: Request, **kwargs: Any self, request: Request, content: JsonDict, **kwargs: Any
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
"""Handle incoming request. """Handle incoming request.
@ -201,6 +214,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
@trace_with_opname("outgoing_replication_request") @trace_with_opname("outgoing_replication_request")
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
# We have to pull these out here to avoid circular dependencies...
streams = hs.get_replication_command_handler().get_streams_to_replicate()
replication = hs.get_replication_data_handler()
with outgoing_gauge.track_inprogress(): with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name: if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self") raise Exception("Trying to send HTTP request to self")
@ -219,6 +236,24 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
data = await cls._serialize_payload(**kwargs) data = await cls._serialize_payload(**kwargs)
if cls.METHOD != "GET" and cls.WAIT_FOR_STREAMS:
# Include the current stream positions that we write to. We
# don't do this for GETs as they don't have a body, and we
# generally assume that a GET won't rely on data we have
# written.
if _STREAM_POSITION_KEY in data:
raise Exception(
"data to send contains %r key", _STREAM_POSITION_KEY
)
data[_STREAM_POSITION_KEY] = {
"streams": {
stream.NAME: stream.current_token(local_instance_name)
for stream in streams
},
"instance_name": local_instance_name,
}
url_args = [ url_args = [
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
] ]
@ -308,6 +343,18 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
) from e ) from e
_outgoing_request_counter.labels(cls.NAME, 200).inc() _outgoing_request_counter.labels(cls.NAME, 200).inc()
# Wait on any streams that the remote may have written to.
for stream_name, position in result.get(
_STREAM_POSITION_KEY, {}
).items():
await replication.wait_for_stream_position(
instance_name=instance_name,
stream_name=stream_name,
position=position,
raise_on_timeout=False,
)
return result return result
return send_request return send_request
@ -353,6 +400,23 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if self._replication_secret: if self._replication_secret:
self._check_auth(request) self._check_auth(request)
if self.METHOD == "GET":
# GET APIs always have an empty body.
content = {}
else:
content = parse_json_object_from_request(request)
# Wait on any streams that the remote may have written to.
for stream_name, position in content.get(_STREAM_POSITION_KEY, {"streams": {}})[
"streams"
].items():
await self._replication.wait_for_stream_position(
instance_name=content[_STREAM_POSITION_KEY]["instance_name"],
stream_name=stream_name,
position=position,
raise_on_timeout=False,
)
if self.CACHE: if self.CACHE:
txn_id = kwargs.pop("txn_id") txn_id = kwargs.pop("txn_id")
@ -361,13 +425,28 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# correctly yet. In particular, there may be issues to do with logging # correctly yet. In particular, there may be issues to do with logging
# context lifetimes. # context lifetimes.
return await self.response_cache.wrap( code, response = await self.response_cache.wrap(
txn_id, self._handle_request, request, **kwargs txn_id, self._handle_request, request, content, **kwargs
) )
else:
# The `@cancellable` decorator may be applied to `_handle_request`. But we # The `@cancellable` decorator may be applied to `_handle_request`. But we
# told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`, # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
# so we have to set up the cancellable flag ourselves. # so we have to set up the cancellable flag ourselves.
request.is_render_cancellable = is_function_cancellable(self._handle_request) request.is_render_cancellable = is_function_cancellable(
self._handle_request
)
return await self._handle_request(request, **kwargs) code, response = await self._handle_request(request, content, **kwargs)
# Return streams we may have written to in the course of processing this
# request.
if _STREAM_POSITION_KEY in response:
raise Exception("data to send contains %r key", _STREAM_POSITION_KEY)
if self.WAIT_FOR_STREAMS:
response[_STREAM_POSITION_KEY] = {
stream.NAME: stream.current_token(self._instance_name)
for stream in self._streams
}
return code, response

View file

@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request from twisted.web.server import Request
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict from synapse.types import JsonDict
@ -61,10 +60,8 @@ class ReplicationAddUserAccountDataRestServlet(ReplicationEndpoint):
return payload return payload
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, account_data_type: str self, request: Request, content: JsonDict, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_account_data_for_user( max_stream_id = await self.handler.add_account_data_for_user(
user_id, account_data_type, content["content"] user_id, account_data_type, content["content"]
) )
@ -101,7 +98,7 @@ class ReplicationRemoveUserAccountDataRestServlet(ReplicationEndpoint):
return {} return {}
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, account_data_type: str self, request: Request, content: JsonDict, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_account_data_for_user( max_stream_id = await self.handler.remove_account_data_for_user(
user_id, account_data_type user_id, account_data_type
@ -143,10 +140,13 @@ class ReplicationAddRoomAccountDataRestServlet(ReplicationEndpoint):
return payload return payload
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, account_data_type: str self,
request: Request,
content: JsonDict,
user_id: str,
room_id: str,
account_data_type: str,
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_account_data_to_room( max_stream_id = await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, content["content"] user_id, room_id, account_data_type, content["content"]
) )
@ -183,7 +183,12 @@ class ReplicationRemoveRoomAccountDataRestServlet(ReplicationEndpoint):
return {} return {}
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, account_data_type: str self,
request: Request,
content: JsonDict,
user_id: str,
room_id: str,
account_data_type: str,
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_account_data_for_room( max_stream_id = await self.handler.remove_account_data_for_room(
user_id, room_id, account_data_type user_id, room_id, account_data_type
@ -225,10 +230,8 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint):
return payload return payload
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, tag: str self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_tag_to_room( max_stream_id = await self.handler.add_tag_to_room(
user_id, room_id, tag, content["content"] user_id, room_id, tag, content["content"]
) )
@ -266,7 +269,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
return {} return {}
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, tag: str self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_tag_from_room( max_stream_id = await self.handler.remove_tag_from_room(
user_id, user_id,

View file

@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from twisted.web.server import Request from twisted.web.server import Request
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.logging.opentracing import active_span from synapse.logging.opentracing import active_span
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict from synapse.types import JsonDict
@ -78,7 +77,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
return {} return {}
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, Optional[JsonDict]]: ) -> Tuple[int, Optional[JsonDict]]:
user_devices = await self.device_list_updater.user_device_resync(user_id) user_devices = await self.device_list_updater.user_device_resync(user_id)
@ -138,9 +137,8 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
return {"user_ids": user_ids} return {"user_ids": user_ids}
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request self, request: Request, content: JsonDict
) -> Tuple[int, Dict[str, Optional[JsonDict]]]: ) -> Tuple[int, Dict[str, Optional[JsonDict]]]:
content = parse_json_object_from_request(request)
user_ids: List[str] = content["user_ids"] user_ids: List[str] = content["user_ids"]
logger.info("Resync for %r", user_ids) logger.info("Resync for %r", user_ids)
@ -205,10 +203,8 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
} }
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request self, request: Request, content: JsonDict
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
user_id = content["user_id"] user_id = content["user_id"]
device_id = content["device_id"] device_id = content["device_id"]
keys = content["keys"] keys = content["keys"]

View file

@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase, make_event_from_dict from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -114,10 +113,8 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
return payload return payload
async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # type: ignore[override] async def _handle_request(self, request: Request, content: JsonDict) -> Tuple[int, JsonDict]: # type: ignore[override]
with Measure(self.clock, "repl_fed_send_events_parse"): with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)
room_id = content["room_id"] room_id = content["room_id"]
backfilled = content["backfilled"] backfilled = content["backfilled"]
@ -181,11 +178,8 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
return {"origin": origin, "content": content} return {"origin": origin, "content": content}
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, edu_type: str self, request: Request, content: JsonDict, edu_type: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_fed_send_edu_parse"):
content = parse_json_object_from_request(request)
origin = content["origin"] origin = content["origin"]
edu_content = content["content"] edu_content = content["content"]
@ -231,11 +225,8 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
return {"args": args} return {"args": args}
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, query_type: str self, request: Request, content: JsonDict, query_type: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_fed_query_parse"):
content = parse_json_object_from_request(request)
args = content["args"] args = content["args"]
args["origin"] = content["origin"] args["origin"] = content["origin"]
@ -274,7 +265,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
return {} return {}
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str self, request: Request, content: JsonDict, room_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await self.store.clean_room_for_join(room_id) await self.store.clean_room_for_join(room_id)
@ -307,9 +298,8 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
return {"room_version": room_version.identifier} return {"room_version": room_version.identifier}
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str self, request: Request, content: JsonDict, room_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]] room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
await self.store.maybe_store_room_on_outlier_membership(room_id, room_version) await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
return 200, {} return 200, {}

View file

@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional, Tuple, cast
from twisted.web.server import Request from twisted.web.server import Request
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict from synapse.types import JsonDict
@ -73,10 +72,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
} }
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
device_id = content["device_id"] device_id = content["device_id"]
initial_display_name = content["initial_display_name"] initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"] is_guest = content["is_guest"]

View file

@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.web.server import Request from twisted.web.server import Request
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester, UserID from synapse.types import JsonDict, Requester, UserID
@ -79,10 +78,8 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
} }
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: SynapseRequest, room_id: str, user_id: str self, request: SynapseRequest, content: JsonDict, room_id: str, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"] remote_room_hosts = content["remote_room_hosts"]
event_content = content["content"] event_content = content["content"]
@ -147,11 +144,10 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, self,
request: SynapseRequest, request: SynapseRequest,
content: JsonDict,
room_id: str, room_id: str,
user_id: str, user_id: str,
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"] remote_room_hosts = content["remote_room_hosts"]
event_content = content["content"] event_content = content["content"]
@ -217,10 +213,8 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
} }
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: SynapseRequest, invite_event_id: str self, request: SynapseRequest, content: JsonDict, invite_event_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
txn_id = content["txn_id"] txn_id = content["txn_id"]
event_content = content["content"] event_content = content["content"]
@ -285,10 +279,9 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, self,
request: SynapseRequest, request: SynapseRequest,
content: JsonDict,
knock_event_id: str, knock_event_id: str,
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
txn_id = content["txn_id"] txn_id = content["txn_id"]
event_content = content["content"] event_content = content["content"]
@ -347,7 +340,12 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
return {} return {}
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str, user_id: str, change: str self,
request: Request,
content: JsonDict,
room_id: str,
user_id: str,
change: str,
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
logger.info("user membership change: %s in %s", user_id, room_id) logger.info("user membership change: %s in %s", user_id, room_id)

View file

@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request from twisted.web.server import Request
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
@ -56,7 +55,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
return {} return {}
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await self._presence_handler.bump_presence_active_time( await self._presence_handler.bump_presence_active_time(
UserID.from_string(user_id) UserID.from_string(user_id)
@ -107,10 +106,8 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
} }
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
await self._presence_handler.set_state( await self._presence_handler.set_state(
UserID.from_string(user_id), UserID.from_string(user_id),
content["state"], content["state"],

View file

@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request from twisted.web.server import Request
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict from synapse.types import JsonDict
@ -61,10 +60,8 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
return payload return payload
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
app_id = content["app_id"] app_id = content["app_id"]
pushkey = content["pushkey"] pushkey = content["pushkey"]

View file

@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional, Tuple
from twisted.web.server import Request from twisted.web.server import Request
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict from synapse.types import JsonDict
@ -96,10 +95,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
} }
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
await self.registration_handler.check_registration_ratelimit(content["address"]) await self.registration_handler.check_registration_ratelimit(content["address"])
# Always default admin users to approved (since it means they were created by # Always default admin users to approved (since it means they were created by
@ -150,10 +147,8 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
return {"auth_result": auth_result, "access_token": access_token} return {"auth_result": auth_result, "access_token": access_token}
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
auth_result = content["auth_result"] auth_result = content["auth_result"]
access_token = content["access_token"] access_token = content["access_token"]

View file

@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase, make_event_from_dict from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester, UserID from synapse.types import JsonDict, Requester, UserID
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -114,11 +113,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
return payload return payload
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, event_id: str self, request: Request, content: JsonDict, event_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_send_event_parse"): with Measure(self.clock, "repl_send_event_parse"):
content = parse_json_object_from_request(request)
event_dict = content["event"] event_dict = content["event"]
room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]] room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]]
internal_metadata = content["internal_metadata"] internal_metadata = content["internal_metadata"]

View file

@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase, make_event_from_dict from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester, UserID from synapse.types import JsonDict, Requester, UserID
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -114,10 +113,9 @@ class ReplicationSendEventsRestServlet(ReplicationEndpoint):
return payload return payload
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request self, request: Request, payload: JsonDict
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_send_events_parse"): with Measure(self.clock, "repl_send_events_parse"):
payload = parse_json_object_from_request(request)
events_and_context = [] events_and_context = []
events = payload["events"] events = payload["events"]

View file

@ -57,7 +57,7 @@ class ReplicationUpdateCurrentStateRestServlet(ReplicationEndpoint):
return {} return {}
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str self, request: Request, content: JsonDict, room_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
writer_instance = self._events_shard_config.get_instance(room_id) writer_instance = self._events_shard_config.get_instance(room_id)
if writer_instance != self._instance_name: if writer_instance != self._instance_name:

View file

@ -54,6 +54,10 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
PATH_ARGS = ("stream_name",) PATH_ARGS = ("stream_name",)
METHOD = "GET" METHOD = "GET"
# We don't want to wait for replication streams to catch up, as this gets
# called in the process of catching replication streams up.
WAIT_FOR_STREAMS = False
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
@ -67,7 +71,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
return {"from_token": from_token, "upto_token": upto_token} return {"from_token": from_token, "upto_token": upto_token}
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request, stream_name: str self, request: Request, content: JsonDict, stream_name: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
stream = self.streams.get(stream_name) stream = self.streams.get(stream_name)
if stream is None: if stream is None:

View file

@ -16,6 +16,7 @@
import logging import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from twisted.internet import defer
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IAddress, IConnector from twisted.internet.interfaces import IAddress, IConnector
from twisted.internet.protocol import ReconnectingClientFactory from twisted.internet.protocol import ReconnectingClientFactory
@ -314,10 +315,21 @@ class ReplicationDataHandler:
self.send_handler.wake_destination(server) self.send_handler.wake_destination(server)
async def wait_for_stream_position( async def wait_for_stream_position(
self, instance_name: str, stream_name: str, position: int self,
instance_name: str,
stream_name: str,
position: int,
raise_on_timeout: bool = True,
) -> None: ) -> None:
"""Wait until this instance has received updates up to and including """Wait until this instance has received updates up to and including
the given stream position. the given stream position.
Args:
instance_name
stream_name
position
raise_on_timeout: Whether to raise an exception if we time out
waiting for the updates, or if we log an error and return.
""" """
if instance_name == self._instance_name: if instance_name == self._instance_name:
@ -345,7 +357,16 @@ class ReplicationDataHandler:
# We measure here to get in flight counts and average waiting time. # We measure here to get in flight counts and average waiting time.
with Measure(self._clock, "repl.wait_for_stream_position"): with Measure(self._clock, "repl.wait_for_stream_position"):
logger.info("Waiting for repl stream %r to reach %s", stream_name, position) logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
try:
await make_deferred_yieldable(deferred) await make_deferred_yieldable(deferred)
except defer.TimeoutError:
logger.error("Timed out waiting for stream %s", stream_name)
if raise_on_timeout:
raise
return
logger.info( logger.info(
"Finished waiting for repl stream %r to reach %s", stream_name, position "Finished waiting for repl stream %r to reach %s", stream_name, position
) )

View file

@ -199,11 +199,6 @@ class ReplicationStreamer:
# The token has advanced but there is no data to # The token has advanced but there is no data to
# send, so we send a `POSITION` to inform other # send, so we send a `POSITION` to inform other
# workers of the updated position. # workers of the updated position.
if stream.NAME == EventsStream.NAME:
# XXX: We only do this for the EventStream as it
# turns out that e.g. account data streams share
# their "current token" with each other, meaning
# that it is *not* safe to send a POSITION.
# Note: `last_token` may not *actually* be the # Note: `last_token` may not *actually* be the
# last token we sent out in a RDATA or POSITION. # last token we sent out in a RDATA or POSITION.

View file

@ -378,6 +378,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._current_positions.values(), default=1 self._current_positions.values(), default=1
) )
if not writers:
# If there have been no explicit writers given then any instance can
# write to the stream. In which case, let's pre-seed our own
# position with the current minimum.
self._current_positions[self._instance_name] = self._persisted_upto_position
def _load_current_ids( def _load_current_ids(
self, self,
db_conn: LoggingDatabaseConnection, db_conn: LoggingDatabaseConnection,
@ -695,24 +701,22 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
heapq.heappush(self._known_persisted_positions, new_id) heapq.heappush(self._known_persisted_positions, new_id)
# If we're a writer and we don't have any active writes we update our
# current position to the latest position seen. This allows the instance
# to report a recent position when asked, rather than a potentially old
# one (if this instance hasn't written anything for a while).
our_current_position = self._current_positions.get(self._instance_name)
if (
our_current_position
and not self._unfinished_ids
and not self._in_flight_fetches
):
self._current_positions[self._instance_name] = max(
our_current_position, new_id
)
# We move the current min position up if the minimum current positions # We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less # of all instances is higher (since by definition all positions less
# that that have been persisted). # that that have been persisted).
min_curr = min(self._current_positions.values(), default=0) our_current_position = self._current_positions.get(self._instance_name, 0)
min_curr = min(
(
token
for name, token in self._current_positions.items()
if name != self._instance_name
),
default=our_current_position,
)
if our_current_position and (self._unfinished_ids or self._in_flight_fetches):
min_curr = min(min_curr, our_current_position)
self._persisted_upto_position = max(min_curr, self._persisted_upto_position) self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
# We now iterate through the seen positions, discarding those that are # We now iterate through the seen positions, discarding those that are

View file

@ -604,6 +604,12 @@ class RoomStreamToken:
elif self.instance_map: elif self.instance_map:
entries = [] entries = []
for name, pos in self.instance_map.items(): for name, pos in self.instance_map.items():
if pos <= self.stream:
# Ignore instances who are below the minimum stream position
# (we might know they've advanced without seeing a recent
# write from them).
continue
instance_id = await store.get_id_for_instance(name) instance_id = await store.get_id_for_instance(name)
entries.append(f"{instance_id}.{pos}") entries.append(f"{instance_id}.{pos}")

View file

@ -44,7 +44,7 @@ class CancellableReplicationEndpoint(ReplicationEndpoint):
@cancellable @cancellable
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request self, request: Request, content: JsonDict
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await self.clock.sleep(1.0) await self.clock.sleep(1.0)
return HTTPStatus.OK, {"result": True} return HTTPStatus.OK, {"result": True}
@ -54,6 +54,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
NAME = "uncancellable_sleep" NAME = "uncancellable_sleep"
PATH_ARGS = () PATH_ARGS = ()
CACHE = False CACHE = False
WAIT_FOR_STREAMS = False
def __init__(self, hs: HomeServer): def __init__(self, hs: HomeServer):
super().__init__(hs) super().__init__(hs)
@ -64,7 +65,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
return {} return {}
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
self, request: Request self, request: Request, content: JsonDict
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await self.clock.sleep(1.0) await self.clock.sleep(1.0)
return HTTPStatus.OK, {"result": True} return HTTPStatus.OK, {"result": True}
@ -85,7 +86,7 @@ class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
def test_cancellable_disconnect(self) -> None: def test_cancellable_disconnect(self) -> None:
"""Test that handlers with the `@cancellable` flag can be cancelled.""" """Test that handlers with the `@cancellable` flag can be cancelled."""
path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/" path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False) channel = self.make_request("POST", path, await_result=False, content={})
test_disconnect( test_disconnect(
self.reactor, self.reactor,
channel, channel,
@ -96,7 +97,7 @@ class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
def test_uncancellable_disconnect(self) -> None: def test_uncancellable_disconnect(self) -> None:
"""Test that handlers without the `@cancellable` flag cannot be cancelled.""" """Test that handlers without the `@cancellable` flag cannot be cancelled."""
path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/" path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False) channel = self.make_request("POST", path, await_result=False, content={})
test_disconnect( test_disconnect(
self.reactor, self.reactor,
channel, channel,

View file

@ -349,8 +349,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# The first ID gen will notice that it can advance its token to 7 as it # The first ID gen will notice that it can advance its token to 7 as it
# has no in progress writes... # has no in progress writes...
self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7}) self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
# ... but the second ID gen doesn't know that. # ... but the second ID gen doesn't know that.
@ -366,8 +366,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8) self.assertEqual(stream_id, 8)
self.assertEqual( self.assertEqual(
first_id_gen.get_positions(), {"first": 7, "second": 7} first_id_gen.get_positions(), {"first": 3, "second": 7}
) )
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
self.get_success(_get_next_async()) self.get_success(_get_next_async())
@ -473,7 +474,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator("first", writers=["first", "second"]) id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5}) self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
self.assertEqual(id_gen.get_persisted_upto_position(), 5) self.assertEqual(id_gen.get_persisted_upto_position(), 5)
@ -720,7 +721,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(_get_next_async2()) self.get_success(_get_next_async2())
self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2}) self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2) self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
@ -816,15 +817,12 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
first_id_gen = self._create_id_generator("first", writers=["first", "second"]) first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"])
# The first ID gen will notice that it can advance its token to 7 as it self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
# has no in progress writes... self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
# ... but the second ID gen doesn't know that.
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) self.assertEqual(second_id_gen.get_persisted_upto_position(), 7)