mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-15 22:42:23 +01:00
Wait for streams to catch up when processing HTTP replication. (#14820)
This should hopefully mitigate a class of races where data gets out of sync due a HTTP replication request racing with the replication streams.
This commit is contained in:
parent
e8f2bf5c40
commit
9187fd940e
21 changed files with 225 additions and 144 deletions
1
changelog.d/14820.bugfix
Normal file
1
changelog.d/14820.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix rare races when using workers.
|
|
@ -2259,6 +2259,10 @@ class FederationEventHandler:
|
||||||
event_and_contexts, backfilled=backfilled
|
event_and_contexts, backfilled=backfilled
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# After persistence we always need to notify replication there may
|
||||||
|
# be new data.
|
||||||
|
self._notifier.notify_replication()
|
||||||
|
|
||||||
if self._ephemeral_messages_enabled:
|
if self._ephemeral_messages_enabled:
|
||||||
for event in events:
|
for event in events:
|
||||||
# If there's an expiry timestamp on the event, schedule its expiry.
|
# If there's an expiry timestamp on the event, schedule its expiry.
|
||||||
|
|
|
@ -17,7 +17,7 @@ import logging
|
||||||
import re
|
import re
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable, ClassVar, Dict, List, Tuple
|
||||||
|
|
||||||
from prometheus_client import Counter, Gauge
|
from prometheus_client import Counter, Gauge
|
||||||
|
|
||||||
|
@ -27,6 +27,7 @@ from twisted.web.server import Request
|
||||||
from synapse.api.errors import HttpResponseException, SynapseError
|
from synapse.api.errors import HttpResponseException, SynapseError
|
||||||
from synapse.http import RequestTimedOutError
|
from synapse.http import RequestTimedOutError
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging import opentracing
|
from synapse.logging import opentracing
|
||||||
from synapse.logging.opentracing import trace_with_opname
|
from synapse.logging.opentracing import trace_with_opname
|
||||||
|
@ -53,6 +54,9 @@ _outgoing_request_counter = Counter(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_STREAM_POSITION_KEY = "_INT_STREAM_POS"
|
||||||
|
|
||||||
|
|
||||||
class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||||
"""Helper base class for defining new replication HTTP endpoints.
|
"""Helper base class for defining new replication HTTP endpoints.
|
||||||
|
|
||||||
|
@ -94,6 +98,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||||
a connection error is received.
|
a connection error is received.
|
||||||
RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when
|
RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when
|
||||||
receiving connection errors, each will backoff exponentially longer.
|
receiving connection errors, each will backoff exponentially longer.
|
||||||
|
WAIT_FOR_STREAMS (bool): Whether to wait for replication streams to
|
||||||
|
catch up before processing the request and/or response. Defaults to
|
||||||
|
True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
NAME: str = abc.abstractproperty() # type: ignore
|
NAME: str = abc.abstractproperty() # type: ignore
|
||||||
|
@ -104,6 +111,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||||
RETRY_ON_CONNECT_ERROR = True
|
RETRY_ON_CONNECT_ERROR = True
|
||||||
RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5 # =63s (2^6-1)
|
RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5 # =63s (2^6-1)
|
||||||
|
|
||||||
|
WAIT_FOR_STREAMS: ClassVar[bool] = True
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
if self.CACHE:
|
if self.CACHE:
|
||||||
self.response_cache: ResponseCache[str] = ResponseCache(
|
self.response_cache: ResponseCache[str] = ResponseCache(
|
||||||
|
@ -126,6 +135,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||||
if hs.config.worker.worker_replication_secret:
|
if hs.config.worker.worker_replication_secret:
|
||||||
self._replication_secret = hs.config.worker.worker_replication_secret
|
self._replication_secret = hs.config.worker.worker_replication_secret
|
||||||
|
|
||||||
|
self._streams = hs.get_replication_command_handler().get_streams_to_replicate()
|
||||||
|
self._replication = hs.get_replication_data_handler()
|
||||||
|
self._instance_name = hs.get_instance_name()
|
||||||
|
|
||||||
def _check_auth(self, request: Request) -> None:
|
def _check_auth(self, request: Request) -> None:
|
||||||
# Get the authorization header.
|
# Get the authorization header.
|
||||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||||
|
@ -160,7 +173,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def _handle_request(
|
async def _handle_request(
|
||||||
self, request: Request, **kwargs: Any
|
self, request: Request, content: JsonDict, **kwargs: Any
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
"""Handle incoming request.
|
"""Handle incoming request.
|
||||||
|
|
||||||
|
@ -201,6 +214,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
@trace_with_opname("outgoing_replication_request")
|
@trace_with_opname("outgoing_replication_request")
|
||||||
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
|
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
|
||||||
|
# We have to pull these out here to avoid circular dependencies...
|
||||||
|
streams = hs.get_replication_command_handler().get_streams_to_replicate()
|
||||||
|
replication = hs.get_replication_data_handler()
|
||||||
|
|
||||||
with outgoing_gauge.track_inprogress():
|
with outgoing_gauge.track_inprogress():
|
||||||
if instance_name == local_instance_name:
|
if instance_name == local_instance_name:
|
||||||
raise Exception("Trying to send HTTP request to self")
|
raise Exception("Trying to send HTTP request to self")
|
||||||
|
@ -219,6 +236,24 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
data = await cls._serialize_payload(**kwargs)
|
data = await cls._serialize_payload(**kwargs)
|
||||||
|
|
||||||
|
if cls.METHOD != "GET" and cls.WAIT_FOR_STREAMS:
|
||||||
|
# Include the current stream positions that we write to. We
|
||||||
|
# don't do this for GETs as they don't have a body, and we
|
||||||
|
# generally assume that a GET won't rely on data we have
|
||||||
|
# written.
|
||||||
|
if _STREAM_POSITION_KEY in data:
|
||||||
|
raise Exception(
|
||||||
|
"data to send contains %r key", _STREAM_POSITION_KEY
|
||||||
|
)
|
||||||
|
|
||||||
|
data[_STREAM_POSITION_KEY] = {
|
||||||
|
"streams": {
|
||||||
|
stream.NAME: stream.current_token(local_instance_name)
|
||||||
|
for stream in streams
|
||||||
|
},
|
||||||
|
"instance_name": local_instance_name,
|
||||||
|
}
|
||||||
|
|
||||||
url_args = [
|
url_args = [
|
||||||
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
|
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
|
||||||
]
|
]
|
||||||
|
@ -308,6 +343,18 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
_outgoing_request_counter.labels(cls.NAME, 200).inc()
|
_outgoing_request_counter.labels(cls.NAME, 200).inc()
|
||||||
|
|
||||||
|
# Wait on any streams that the remote may have written to.
|
||||||
|
for stream_name, position in result.get(
|
||||||
|
_STREAM_POSITION_KEY, {}
|
||||||
|
).items():
|
||||||
|
await replication.wait_for_stream_position(
|
||||||
|
instance_name=instance_name,
|
||||||
|
stream_name=stream_name,
|
||||||
|
position=position,
|
||||||
|
raise_on_timeout=False,
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return send_request
|
return send_request
|
||||||
|
@ -353,6 +400,23 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||||
if self._replication_secret:
|
if self._replication_secret:
|
||||||
self._check_auth(request)
|
self._check_auth(request)
|
||||||
|
|
||||||
|
if self.METHOD == "GET":
|
||||||
|
# GET APIs always have an empty body.
|
||||||
|
content = {}
|
||||||
|
else:
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
# Wait on any streams that the remote may have written to.
|
||||||
|
for stream_name, position in content.get(_STREAM_POSITION_KEY, {"streams": {}})[
|
||||||
|
"streams"
|
||||||
|
].items():
|
||||||
|
await self._replication.wait_for_stream_position(
|
||||||
|
instance_name=content[_STREAM_POSITION_KEY]["instance_name"],
|
||||||
|
stream_name=stream_name,
|
||||||
|
position=position,
|
||||||
|
raise_on_timeout=False,
|
||||||
|
)
|
||||||
|
|
||||||
if self.CACHE:
|
if self.CACHE:
|
||||||
txn_id = kwargs.pop("txn_id")
|
txn_id = kwargs.pop("txn_id")
|
||||||
|
|
||||||
|
@ -361,13 +425,28 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||||
# correctly yet. In particular, there may be issues to do with logging
|
# correctly yet. In particular, there may be issues to do with logging
|
||||||
# context lifetimes.
|
# context lifetimes.
|
||||||
|
|
||||||
return await self.response_cache.wrap(
|
code, response = await self.response_cache.wrap(
|
||||||
txn_id, self._handle_request, request, **kwargs
|
txn_id, self._handle_request, request, content, **kwargs
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
# The `@cancellable` decorator may be applied to `_handle_request`. But we
|
# The `@cancellable` decorator may be applied to `_handle_request`. But we
|
||||||
# told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
|
# told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
|
||||||
# so we have to set up the cancellable flag ourselves.
|
# so we have to set up the cancellable flag ourselves.
|
||||||
request.is_render_cancellable = is_function_cancellable(self._handle_request)
|
request.is_render_cancellable = is_function_cancellable(
|
||||||
|
self._handle_request
|
||||||
|
)
|
||||||
|
|
||||||
return await self._handle_request(request, **kwargs)
|
code, response = await self._handle_request(request, content, **kwargs)
|
||||||
|
|
||||||
|
# Return streams we may have written to in the course of processing this
|
||||||
|
# request.
|
||||||
|
if _STREAM_POSITION_KEY in response:
|
||||||
|
raise Exception("data to send contains %r key", _STREAM_POSITION_KEY)
|
||||||
|
|
||||||
|
if self.WAIT_FOR_STREAMS:
|
||||||
|
response[_STREAM_POSITION_KEY] = {
|
||||||
|
stream.NAME: stream.current_token(self._instance_name)
|
||||||
|
for stream in self._streams
|
||||||
|
}
|
||||||
|
|
||||||
|
return code, response
|
||||||
|
|
|
@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
@ -61,10 +60,8 @@ class ReplicationAddUserAccountDataRestServlet(ReplicationEndpoint):
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, user_id: str, account_data_type: str
|
self, request: Request, content: JsonDict, user_id: str, account_data_type: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
max_stream_id = await self.handler.add_account_data_for_user(
|
max_stream_id = await self.handler.add_account_data_for_user(
|
||||||
user_id, account_data_type, content["content"]
|
user_id, account_data_type, content["content"]
|
||||||
)
|
)
|
||||||
|
@ -101,7 +98,7 @@ class ReplicationRemoveUserAccountDataRestServlet(ReplicationEndpoint):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, user_id: str, account_data_type: str
|
self, request: Request, content: JsonDict, user_id: str, account_data_type: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
max_stream_id = await self.handler.remove_account_data_for_user(
|
max_stream_id = await self.handler.remove_account_data_for_user(
|
||||||
user_id, account_data_type
|
user_id, account_data_type
|
||||||
|
@ -143,10 +140,13 @@ class ReplicationAddRoomAccountDataRestServlet(ReplicationEndpoint):
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, user_id: str, room_id: str, account_data_type: str
|
self,
|
||||||
|
request: Request,
|
||||||
|
content: JsonDict,
|
||||||
|
user_id: str,
|
||||||
|
room_id: str,
|
||||||
|
account_data_type: str,
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
max_stream_id = await self.handler.add_account_data_to_room(
|
max_stream_id = await self.handler.add_account_data_to_room(
|
||||||
user_id, room_id, account_data_type, content["content"]
|
user_id, room_id, account_data_type, content["content"]
|
||||||
)
|
)
|
||||||
|
@ -183,7 +183,12 @@ class ReplicationRemoveRoomAccountDataRestServlet(ReplicationEndpoint):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, user_id: str, room_id: str, account_data_type: str
|
self,
|
||||||
|
request: Request,
|
||||||
|
content: JsonDict,
|
||||||
|
user_id: str,
|
||||||
|
room_id: str,
|
||||||
|
account_data_type: str,
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
max_stream_id = await self.handler.remove_account_data_for_room(
|
max_stream_id = await self.handler.remove_account_data_for_room(
|
||||||
user_id, room_id, account_data_type
|
user_id, room_id, account_data_type
|
||||||
|
@ -225,10 +230,8 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint):
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, user_id: str, room_id: str, tag: str
|
self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
max_stream_id = await self.handler.add_tag_to_room(
|
max_stream_id = await self.handler.add_tag_to_room(
|
||||||
user_id, room_id, tag, content["content"]
|
user_id, room_id, tag, content["content"]
|
||||||
)
|
)
|
||||||
|
@ -266,7 +269,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, user_id: str, room_id: str, tag: str
|
self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
max_stream_id = await self.handler.remove_tag_from_room(
|
max_stream_id = await self.handler.remove_tag_from_room(
|
||||||
user_id,
|
user_id,
|
||||||
|
|
|
@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
|
||||||
from synapse.logging.opentracing import active_span
|
from synapse.logging.opentracing import active_span
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
@ -78,7 +77,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, user_id: str
|
self, request: Request, content: JsonDict, user_id: str
|
||||||
) -> Tuple[int, Optional[JsonDict]]:
|
) -> Tuple[int, Optional[JsonDict]]:
|
||||||
user_devices = await self.device_list_updater.user_device_resync(user_id)
|
user_devices = await self.device_list_updater.user_device_resync(user_id)
|
||||||
|
|
||||||
|
@ -138,9 +137,8 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
|
||||||
return {"user_ids": user_ids}
|
return {"user_ids": user_ids}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request
|
self, request: Request, content: JsonDict
|
||||||
) -> Tuple[int, Dict[str, Optional[JsonDict]]]:
|
) -> Tuple[int, Dict[str, Optional[JsonDict]]]:
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
user_ids: List[str] = content["user_ids"]
|
user_ids: List[str] = content["user_ids"]
|
||||||
|
|
||||||
logger.info("Resync for %r", user_ids)
|
logger.info("Resync for %r", user_ids)
|
||||||
|
@ -205,10 +203,8 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request
|
self, request: Request, content: JsonDict
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
user_id = content["user_id"]
|
user_id = content["user_id"]
|
||||||
device_id = content["device_id"]
|
device_id = content["device_id"]
|
||||||
keys = content["keys"]
|
keys = content["keys"]
|
||||||
|
|
|
@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||||
from synapse.events import EventBase, make_event_from_dict
|
from synapse.events import EventBase, make_event_from_dict
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
@ -114,10 +113,8 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # type: ignore[override]
|
async def _handle_request(self, request: Request, content: JsonDict) -> Tuple[int, JsonDict]: # type: ignore[override]
|
||||||
with Measure(self.clock, "repl_fed_send_events_parse"):
|
with Measure(self.clock, "repl_fed_send_events_parse"):
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
room_id = content["room_id"]
|
room_id = content["room_id"]
|
||||||
backfilled = content["backfilled"]
|
backfilled = content["backfilled"]
|
||||||
|
|
||||||
|
@ -181,11 +178,8 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
|
||||||
return {"origin": origin, "content": content}
|
return {"origin": origin, "content": content}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, edu_type: str
|
self, request: Request, content: JsonDict, edu_type: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
with Measure(self.clock, "repl_fed_send_edu_parse"):
|
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
origin = content["origin"]
|
origin = content["origin"]
|
||||||
edu_content = content["content"]
|
edu_content = content["content"]
|
||||||
|
|
||||||
|
@ -231,11 +225,8 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
|
||||||
return {"args": args}
|
return {"args": args}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, query_type: str
|
self, request: Request, content: JsonDict, query_type: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
with Measure(self.clock, "repl_fed_query_parse"):
|
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
args = content["args"]
|
args = content["args"]
|
||||||
args["origin"] = content["origin"]
|
args["origin"] = content["origin"]
|
||||||
|
|
||||||
|
@ -274,7 +265,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, room_id: str
|
self, request: Request, content: JsonDict, room_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
await self.store.clean_room_for_join(room_id)
|
await self.store.clean_room_for_join(room_id)
|
||||||
|
|
||||||
|
@ -307,9 +298,8 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
|
||||||
return {"room_version": room_version.identifier}
|
return {"room_version": room_version.identifier}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, room_id: str
|
self, request: Request, content: JsonDict, room_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
|
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
|
||||||
await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
|
await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
|
@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional, Tuple, cast
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
@ -73,10 +72,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, user_id: str
|
self, request: Request, content: JsonDict, user_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
device_id = content["device_id"]
|
device_id = content["device_id"]
|
||||||
initial_display_name = content["initial_display_name"]
|
initial_display_name = content["initial_display_name"]
|
||||||
is_guest = content["is_guest"]
|
is_guest = content["is_guest"]
|
||||||
|
|
|
@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
from synapse.types import JsonDict, Requester, UserID
|
from synapse.types import JsonDict, Requester, UserID
|
||||||
|
@ -79,10 +78,8 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: SynapseRequest, room_id: str, user_id: str
|
self, request: SynapseRequest, content: JsonDict, room_id: str, user_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
remote_room_hosts = content["remote_room_hosts"]
|
remote_room_hosts = content["remote_room_hosts"]
|
||||||
event_content = content["content"]
|
event_content = content["content"]
|
||||||
|
|
||||||
|
@ -147,11 +144,10 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
|
content: JsonDict,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
remote_room_hosts = content["remote_room_hosts"]
|
remote_room_hosts = content["remote_room_hosts"]
|
||||||
event_content = content["content"]
|
event_content = content["content"]
|
||||||
|
|
||||||
|
@ -217,10 +213,8 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: SynapseRequest, invite_event_id: str
|
self, request: SynapseRequest, content: JsonDict, invite_event_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
txn_id = content["txn_id"]
|
txn_id = content["txn_id"]
|
||||||
event_content = content["content"]
|
event_content = content["content"]
|
||||||
|
|
||||||
|
@ -285,10 +279,9 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
|
content: JsonDict,
|
||||||
knock_event_id: str,
|
knock_event_id: str,
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
txn_id = content["txn_id"]
|
txn_id = content["txn_id"]
|
||||||
event_content = content["content"]
|
event_content = content["content"]
|
||||||
|
|
||||||
|
@ -347,7 +340,12 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, room_id: str, user_id: str, change: str
|
self,
|
||||||
|
request: Request,
|
||||||
|
content: JsonDict,
|
||||||
|
room_id: str,
|
||||||
|
user_id: str,
|
||||||
|
change: str,
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
logger.info("user membership change: %s in %s", user_id, room_id)
|
logger.info("user membership change: %s in %s", user_id, room_id)
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, UserID
|
||||||
|
|
||||||
|
@ -56,7 +55,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, user_id: str
|
self, request: Request, content: JsonDict, user_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
await self._presence_handler.bump_presence_active_time(
|
await self._presence_handler.bump_presence_active_time(
|
||||||
UserID.from_string(user_id)
|
UserID.from_string(user_id)
|
||||||
|
@ -107,10 +106,8 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, user_id: str
|
self, request: Request, content: JsonDict, user_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
await self._presence_handler.set_state(
|
await self._presence_handler.set_state(
|
||||||
UserID.from_string(user_id),
|
UserID.from_string(user_id),
|
||||||
content["state"],
|
content["state"],
|
||||||
|
|
|
@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
@ -61,10 +60,8 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, user_id: str
|
self, request: Request, content: JsonDict, user_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
app_id = content["app_id"]
|
app_id = content["app_id"]
|
||||||
pushkey = content["pushkey"]
|
pushkey = content["pushkey"]
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional, Tuple
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
@ -96,10 +95,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, user_id: str
|
self, request: Request, content: JsonDict, user_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
await self.registration_handler.check_registration_ratelimit(content["address"])
|
await self.registration_handler.check_registration_ratelimit(content["address"])
|
||||||
|
|
||||||
# Always default admin users to approved (since it means they were created by
|
# Always default admin users to approved (since it means they were created by
|
||||||
|
@ -150,10 +147,8 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
|
||||||
return {"auth_result": auth_result, "access_token": access_token}
|
return {"auth_result": auth_result, "access_token": access_token}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, user_id: str
|
self, request: Request, content: JsonDict, user_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
auth_result = content["auth_result"]
|
auth_result = content["auth_result"]
|
||||||
access_token = content["access_token"]
|
access_token = content["access_token"]
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
from synapse.events import EventBase, make_event_from_dict
|
from synapse.events import EventBase, make_event_from_dict
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
from synapse.types import JsonDict, Requester, UserID
|
from synapse.types import JsonDict, Requester, UserID
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
@ -114,11 +113,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, event_id: str
|
self, request: Request, content: JsonDict, event_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
with Measure(self.clock, "repl_send_event_parse"):
|
with Measure(self.clock, "repl_send_event_parse"):
|
||||||
content = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
event_dict = content["event"]
|
event_dict = content["event"]
|
||||||
room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]]
|
room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]]
|
||||||
internal_metadata = content["internal_metadata"]
|
internal_metadata = content["internal_metadata"]
|
||||||
|
|
|
@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
from synapse.events import EventBase, make_event_from_dict
|
from synapse.events import EventBase, make_event_from_dict
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
from synapse.types import JsonDict, Requester, UserID
|
from synapse.types import JsonDict, Requester, UserID
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
@ -114,10 +113,9 @@ class ReplicationSendEventsRestServlet(ReplicationEndpoint):
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request
|
self, request: Request, payload: JsonDict
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
with Measure(self.clock, "repl_send_events_parse"):
|
with Measure(self.clock, "repl_send_events_parse"):
|
||||||
payload = parse_json_object_from_request(request)
|
|
||||||
events_and_context = []
|
events_and_context = []
|
||||||
events = payload["events"]
|
events = payload["events"]
|
||||||
|
|
||||||
|
|
|
@ -57,7 +57,7 @@ class ReplicationUpdateCurrentStateRestServlet(ReplicationEndpoint):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, room_id: str
|
self, request: Request, content: JsonDict, room_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
writer_instance = self._events_shard_config.get_instance(room_id)
|
writer_instance = self._events_shard_config.get_instance(room_id)
|
||||||
if writer_instance != self._instance_name:
|
if writer_instance != self._instance_name:
|
||||||
|
|
|
@ -54,6 +54,10 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
||||||
PATH_ARGS = ("stream_name",)
|
PATH_ARGS = ("stream_name",)
|
||||||
METHOD = "GET"
|
METHOD = "GET"
|
||||||
|
|
||||||
|
# We don't want to wait for replication streams to catch up, as this gets
|
||||||
|
# called in the process of catching replication streams up.
|
||||||
|
WAIT_FOR_STREAMS = False
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
|
@ -67,7 +71,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
||||||
return {"from_token": from_token, "upto_token": upto_token}
|
return {"from_token": from_token, "upto_token": upto_token}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, stream_name: str
|
self, request: Request, content: JsonDict, stream_name: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
stream = self.streams.get(stream_name)
|
stream = self.streams.get(stream_name)
|
||||||
if stream is None:
|
if stream is None:
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
from twisted.internet.interfaces import IAddress, IConnector
|
from twisted.internet.interfaces import IAddress, IConnector
|
||||||
from twisted.internet.protocol import ReconnectingClientFactory
|
from twisted.internet.protocol import ReconnectingClientFactory
|
||||||
|
@ -314,10 +315,21 @@ class ReplicationDataHandler:
|
||||||
self.send_handler.wake_destination(server)
|
self.send_handler.wake_destination(server)
|
||||||
|
|
||||||
async def wait_for_stream_position(
|
async def wait_for_stream_position(
|
||||||
self, instance_name: str, stream_name: str, position: int
|
self,
|
||||||
|
instance_name: str,
|
||||||
|
stream_name: str,
|
||||||
|
position: int,
|
||||||
|
raise_on_timeout: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Wait until this instance has received updates up to and including
|
"""Wait until this instance has received updates up to and including
|
||||||
the given stream position.
|
the given stream position.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance_name
|
||||||
|
stream_name
|
||||||
|
position
|
||||||
|
raise_on_timeout: Whether to raise an exception if we time out
|
||||||
|
waiting for the updates, or if we log an error and return.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if instance_name == self._instance_name:
|
if instance_name == self._instance_name:
|
||||||
|
@ -345,7 +357,16 @@ class ReplicationDataHandler:
|
||||||
# We measure here to get in flight counts and average waiting time.
|
# We measure here to get in flight counts and average waiting time.
|
||||||
with Measure(self._clock, "repl.wait_for_stream_position"):
|
with Measure(self._clock, "repl.wait_for_stream_position"):
|
||||||
logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
|
logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
|
||||||
|
try:
|
||||||
await make_deferred_yieldable(deferred)
|
await make_deferred_yieldable(deferred)
|
||||||
|
except defer.TimeoutError:
|
||||||
|
logger.error("Timed out waiting for stream %s", stream_name)
|
||||||
|
|
||||||
|
if raise_on_timeout:
|
||||||
|
raise
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Finished waiting for repl stream %r to reach %s", stream_name, position
|
"Finished waiting for repl stream %r to reach %s", stream_name, position
|
||||||
)
|
)
|
||||||
|
|
|
@ -199,11 +199,6 @@ class ReplicationStreamer:
|
||||||
# The token has advanced but there is no data to
|
# The token has advanced but there is no data to
|
||||||
# send, so we send a `POSITION` to inform other
|
# send, so we send a `POSITION` to inform other
|
||||||
# workers of the updated position.
|
# workers of the updated position.
|
||||||
if stream.NAME == EventsStream.NAME:
|
|
||||||
# XXX: We only do this for the EventStream as it
|
|
||||||
# turns out that e.g. account data streams share
|
|
||||||
# their "current token" with each other, meaning
|
|
||||||
# that it is *not* safe to send a POSITION.
|
|
||||||
|
|
||||||
# Note: `last_token` may not *actually* be the
|
# Note: `last_token` may not *actually* be the
|
||||||
# last token we sent out in a RDATA or POSITION.
|
# last token we sent out in a RDATA or POSITION.
|
||||||
|
|
|
@ -378,6 +378,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
self._current_positions.values(), default=1
|
self._current_positions.values(), default=1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not writers:
|
||||||
|
# If there have been no explicit writers given then any instance can
|
||||||
|
# write to the stream. In which case, let's pre-seed our own
|
||||||
|
# position with the current minimum.
|
||||||
|
self._current_positions[self._instance_name] = self._persisted_upto_position
|
||||||
|
|
||||||
def _load_current_ids(
|
def _load_current_ids(
|
||||||
self,
|
self,
|
||||||
db_conn: LoggingDatabaseConnection,
|
db_conn: LoggingDatabaseConnection,
|
||||||
|
@ -695,24 +701,22 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
|
|
||||||
heapq.heappush(self._known_persisted_positions, new_id)
|
heapq.heappush(self._known_persisted_positions, new_id)
|
||||||
|
|
||||||
# If we're a writer and we don't have any active writes we update our
|
|
||||||
# current position to the latest position seen. This allows the instance
|
|
||||||
# to report a recent position when asked, rather than a potentially old
|
|
||||||
# one (if this instance hasn't written anything for a while).
|
|
||||||
our_current_position = self._current_positions.get(self._instance_name)
|
|
||||||
if (
|
|
||||||
our_current_position
|
|
||||||
and not self._unfinished_ids
|
|
||||||
and not self._in_flight_fetches
|
|
||||||
):
|
|
||||||
self._current_positions[self._instance_name] = max(
|
|
||||||
our_current_position, new_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# We move the current min position up if the minimum current positions
|
# We move the current min position up if the minimum current positions
|
||||||
# of all instances is higher (since by definition all positions less
|
# of all instances is higher (since by definition all positions less
|
||||||
# that that have been persisted).
|
# that that have been persisted).
|
||||||
min_curr = min(self._current_positions.values(), default=0)
|
our_current_position = self._current_positions.get(self._instance_name, 0)
|
||||||
|
min_curr = min(
|
||||||
|
(
|
||||||
|
token
|
||||||
|
for name, token in self._current_positions.items()
|
||||||
|
if name != self._instance_name
|
||||||
|
),
|
||||||
|
default=our_current_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
if our_current_position and (self._unfinished_ids or self._in_flight_fetches):
|
||||||
|
min_curr = min(min_curr, our_current_position)
|
||||||
|
|
||||||
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
|
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
|
||||||
|
|
||||||
# We now iterate through the seen positions, discarding those that are
|
# We now iterate through the seen positions, discarding those that are
|
||||||
|
|
|
@ -604,6 +604,12 @@ class RoomStreamToken:
|
||||||
elif self.instance_map:
|
elif self.instance_map:
|
||||||
entries = []
|
entries = []
|
||||||
for name, pos in self.instance_map.items():
|
for name, pos in self.instance_map.items():
|
||||||
|
if pos <= self.stream:
|
||||||
|
# Ignore instances who are below the minimum stream position
|
||||||
|
# (we might know they've advanced without seeing a recent
|
||||||
|
# write from them).
|
||||||
|
continue
|
||||||
|
|
||||||
instance_id = await store.get_id_for_instance(name)
|
instance_id = await store.get_id_for_instance(name)
|
||||||
entries.append(f"{instance_id}.{pos}")
|
entries.append(f"{instance_id}.{pos}")
|
||||||
|
|
||||||
|
|
|
@ -44,7 +44,7 @@ class CancellableReplicationEndpoint(ReplicationEndpoint):
|
||||||
|
|
||||||
@cancellable
|
@cancellable
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request
|
self, request: Request, content: JsonDict
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
await self.clock.sleep(1.0)
|
await self.clock.sleep(1.0)
|
||||||
return HTTPStatus.OK, {"result": True}
|
return HTTPStatus.OK, {"result": True}
|
||||||
|
@ -54,6 +54,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
|
||||||
NAME = "uncancellable_sleep"
|
NAME = "uncancellable_sleep"
|
||||||
PATH_ARGS = ()
|
PATH_ARGS = ()
|
||||||
CACHE = False
|
CACHE = False
|
||||||
|
WAIT_FOR_STREAMS = False
|
||||||
|
|
||||||
def __init__(self, hs: HomeServer):
|
def __init__(self, hs: HomeServer):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
@ -64,7 +65,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request
|
self, request: Request, content: JsonDict
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
await self.clock.sleep(1.0)
|
await self.clock.sleep(1.0)
|
||||||
return HTTPStatus.OK, {"result": True}
|
return HTTPStatus.OK, {"result": True}
|
||||||
|
@ -85,7 +86,7 @@ class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
|
||||||
def test_cancellable_disconnect(self) -> None:
|
def test_cancellable_disconnect(self) -> None:
|
||||||
"""Test that handlers with the `@cancellable` flag can be cancelled."""
|
"""Test that handlers with the `@cancellable` flag can be cancelled."""
|
||||||
path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
|
path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
|
||||||
channel = self.make_request("POST", path, await_result=False)
|
channel = self.make_request("POST", path, await_result=False, content={})
|
||||||
test_disconnect(
|
test_disconnect(
|
||||||
self.reactor,
|
self.reactor,
|
||||||
channel,
|
channel,
|
||||||
|
@ -96,7 +97,7 @@ class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
|
||||||
def test_uncancellable_disconnect(self) -> None:
|
def test_uncancellable_disconnect(self) -> None:
|
||||||
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
|
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
|
||||||
path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
|
path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
|
||||||
channel = self.make_request("POST", path, await_result=False)
|
channel = self.make_request("POST", path, await_result=False, content={})
|
||||||
test_disconnect(
|
test_disconnect(
|
||||||
self.reactor,
|
self.reactor,
|
||||||
channel,
|
channel,
|
||||||
|
|
|
@ -349,8 +349,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# The first ID gen will notice that it can advance its token to 7 as it
|
# The first ID gen will notice that it can advance its token to 7 as it
|
||||||
# has no in progress writes...
|
# has no in progress writes...
|
||||||
self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7})
|
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
|
||||||
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
|
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
|
||||||
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
|
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
|
||||||
|
|
||||||
# ... but the second ID gen doesn't know that.
|
# ... but the second ID gen doesn't know that.
|
||||||
|
@ -366,8 +366,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(stream_id, 8)
|
self.assertEqual(stream_id, 8)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
first_id_gen.get_positions(), {"first": 7, "second": 7}
|
first_id_gen.get_positions(), {"first": 3, "second": 7}
|
||||||
)
|
)
|
||||||
|
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
|
||||||
|
|
||||||
self.get_success(_get_next_async())
|
self.get_success(_get_next_async())
|
||||||
|
|
||||||
|
@ -473,7 +474,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
id_gen = self._create_id_generator("first", writers=["first", "second"])
|
id_gen = self._create_id_generator("first", writers=["first", "second"])
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
|
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
|
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
|
||||||
|
|
||||||
|
@ -720,7 +721,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
self.get_success(_get_next_async2())
|
self.get_success(_get_next_async2())
|
||||||
|
|
||||||
self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2})
|
self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
|
||||||
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
|
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
|
||||||
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
|
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
|
||||||
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
|
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
|
||||||
|
@ -816,15 +817,12 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
|
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
|
||||||
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
|
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
|
||||||
|
|
||||||
# The first ID gen will notice that it can advance its token to 7 as it
|
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
|
||||||
# has no in progress writes...
|
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
|
||||||
self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6})
|
|
||||||
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
|
|
||||||
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
|
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
|
||||||
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
|
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
|
||||||
|
|
||||||
# ... but the second ID gen doesn't know that.
|
|
||||||
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
|
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
|
||||||
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
|
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
|
||||||
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
|
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
|
||||||
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
|
self.assertEqual(second_id_gen.get_persisted_upto_position(), 7)
|
||||||
|
|
Loading…
Reference in a new issue