Fix sync waiting for an invalid token from the "future" (#17386)
Fixes https://github.com/element-hq/synapse/issues/17274, hopefully.

Basically, old versions of Synapse could advance streams without persisting anything in the DB (fixed in #17229). On restart those updates would get lost, and so the position of the stream would revert to an older position. If this happened across an upgrade to a later Synapse version which included #17215, then sync could get blocked indefinitely (until the stream advanced to the position in the token).

We fix this by bounding the stream positions we'll wait for to the maximum position of the underlying stream ID generator.
parent 9c8f1a6d41
commit b3b793786c
17 changed files with 229 additions and 31 deletions
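The mechanism of the fix is easiest to see outside the diff. Below is a minimal, self-contained sketch of the bounding logic described above; the names (`IdGenerator`, a free-standing `bound_future_token`) are illustrative stand-ins, not Synapse's actual API:

# Minimal sketch (hypothetical names, not Synapse's actual classes) of the
# bounding idea: clamp a "future" token to the maximum stream ID the
# underlying ID generator has ever allocated, so that waiting for the
# stream to reach the token can always terminate.

from dataclasses import dataclass


@dataclass
class IdGenerator:
    """Stand-in for a stream ID generator."""

    persisted_position: int  # how far the stream has actually advanced
    max_allocated: int       # highest ID ever handed out by the sequence


def bound_future_token(token: int, current: int, id_gen: IdGenerator) -> int:
    """Clamp a token from the "future" to a position the stream can reach."""
    if token <= current:
        # Not a future token; nothing to do.
        return token
    # The stream can eventually advance to max_allocated, but never beyond
    # it, so any token past that point is invalid and gets clamped down.
    return min(token, id_gen.max_allocated)


# The stream is at 10 and the sequence has allocated up to 12, but a buggy
# older instance handed out token 15. Waiting for 15 would block forever;
# bounding it to 12 lets the wait complete once in-flight writes persist.
gen = IdGenerator(persisted_position=10, max_allocated=12)
assert bound_future_token(15, current=10, id_gen=gen) == 12
assert bound_future_token(11, current=10, id_gen=gen) == 11  # reachable as-is
assert bound_future_token(9, current=10, id_gen=gen) == 9    # not in the future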
changelog.d/17386.bugfix (new file)
@@ -0,0 +1 @@
+Fix bug where `/sync` requests could get blocked indefinitely after an upgrade from Synapse versions before v1.109.0.
@@ -764,6 +764,13 @@ class Notifier:
     async def wait_for_stream_token(self, stream_token: StreamToken) -> bool:
         """Wait for this worker to catch up with the given stream token."""
+        current_token = self.event_sources.get_current_token()
+        if stream_token.is_before_or_eq(current_token):
+            return True
+
+        # Work around a bug where older Synapse versions gave out tokens "from
+        # the future", i.e. that are ahead of the tokens persisted in the DB.
+        stream_token = await self.event_sources.bound_future_token(stream_token)
+
         start = self.clock.time_msec()
         while True:
@@ -43,10 +43,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import JsonDict, JsonMapping
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached

@@ -71,7 +68,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             self._instance_name in hs.config.worker.writers.account_data
         )

-        self._account_data_id_gen: AbstractStreamIdGenerator
+        self._account_data_id_gen: MultiWriterIdGenerator

         self._account_data_id_gen = MultiWriterIdGenerator(
             db_conn=db_conn,

@@ -113,6 +110,9 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         """
         return self._account_data_id_gen.get_current_token()

+    def get_account_data_id_generator(self) -> MultiWriterIdGenerator:
+        return self._account_data_id_gen
+
     @cached()
     async def get_global_account_data_for_user(
         self, user_id: str
@@ -50,10 +50,7 @@ from synapse.storage.database import (
     LoggingTransaction,
     make_in_list_sql_clause,
 )
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache

@@ -92,7 +89,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             self._instance_name in hs.config.worker.writers.to_device
         )

-        self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator(
+        self._to_device_msg_id_gen: MultiWriterIdGenerator = MultiWriterIdGenerator(
             db_conn=db_conn,
             db=database,
             notifier=hs.get_replication_notifier(),

@@ -169,6 +166,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     def get_to_device_stream_token(self) -> int:
         return self._to_device_msg_id_gen.get_current_token()

+    def get_to_device_id_generator(self) -> MultiWriterIdGenerator:
+        return self._to_device_msg_id_gen
+
     async def get_messages_for_user_devices(
         self,
         user_ids: Collection[str],
@@ -243,6 +243,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()

+    def get_device_stream_id_generator(self) -> MultiWriterIdGenerator:
+        return self._device_list_id_gen
+
     async def count_devices_by_users(
         self, user_ids: Optional[Collection[str]] = None
     ) -> int:
@@ -192,8 +192,8 @@ class EventsWorkerStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)

-        self._stream_id_gen: AbstractStreamIdGenerator
-        self._backfill_id_gen: AbstractStreamIdGenerator
+        self._stream_id_gen: MultiWriterIdGenerator
+        self._backfill_id_gen: MultiWriterIdGenerator

         self._stream_id_gen = MultiWriterIdGenerator(
             db_conn=db_conn,
@@ -42,10 +42,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines._base import IsolationLevel
 from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.iterutils import batch_iter

@@ -83,7 +80,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
         super().__init__(database, db_conn, hs)

         self._instance_name = hs.get_instance_name()
-        self._presence_id_gen: AbstractStreamIdGenerator
+        self._presence_id_gen: MultiWriterIdGenerator

         self._can_persist_presence = (
             self._instance_name in hs.config.worker.writers.presence

@@ -455,6 +452,9 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
     def get_current_presence_token(self) -> int:
         return self._presence_id_gen.get_current_token()

+    def get_presence_stream_id_gen(self) -> MultiWriterIdGenerator:
+        return self._presence_id_gen
+
     def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
         """Fetch non-offline presence from the database so that we can register
         the appropriate time outs.
@@ -178,6 +178,9 @@ class PushRulesWorkerStore(
         """
         return self._push_rules_stream_id_gen.get_current_token()

+    def get_push_rules_stream_id_gen(self) -> MultiWriterIdGenerator:
+        return self._push_rules_stream_id_gen
+
     def process_replication_rows(
         self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
     ) -> None:
@@ -45,10 +45,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.engines._base import IsolationLevel
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import (
     JsonDict,
     JsonMapping,

@@ -76,7 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore):

         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._receipts_id_gen: AbstractStreamIdGenerator
+        self._receipts_id_gen: MultiWriterIdGenerator

         self._can_write_to_receipts = (
             self._instance_name in hs.config.worker.writers.receipts

@@ -136,6 +133,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
     def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
         return self._receipts_id_gen.get_current_token_for_writer(instance_name)

+    def get_receipts_stream_id_gen(self) -> MultiWriterIdGenerator:
+        return self._receipts_id_gen
+
     def get_last_unthreaded_receipt_for_user_txn(
         self,
         txn: LoggingTransaction,
@@ -59,11 +59,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    IdGenerator,
-    MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
 from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList

@@ -151,7 +147,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):

         self.config: HomeServerConfig = hs.config

-        self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
+        self._un_partial_stated_rooms_stream_id_gen: MultiWriterIdGenerator

         self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
             db_conn=db_conn,

@@ -1409,6 +1405,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             instance_name
         )

+    def get_un_partial_stated_rooms_id_generator(self) -> MultiWriterIdGenerator:
+        return self._un_partial_stated_rooms_stream_id_gen
+
     async def get_un_partial_stated_rooms_between(
         self, last_id: int, current_id: int, room_ids: Collection[str]
     ) -> Set[str]:
@@ -577,6 +577,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):

         return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))

+    def get_events_stream_id_generator(self) -> MultiWriterIdGenerator:
+        return self._stream_id_gen
+
     async def get_room_events_stream_for_rooms(
         self,
         room_ids: Collection[str],
@@ -812,6 +812,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             pos = self.get_current_token_for_writer(self._instance_name)
             txn.execute(sql, (self._stream_name, self._instance_name, pos))

+    async def get_max_allocated_token(self) -> int:
+        return await self._db.runInteraction(
+            "get_max_allocated_token", self._sequence_gen.get_max_allocated
+        )
+

 @attr.s(frozen=True, auto_attribs=True)
 class _AsyncCtxManagerWrapper(Generic[T]):
@@ -88,6 +88,10 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
         """
         ...

+    @abc.abstractmethod
+    def get_max_allocated(self, txn: Cursor) -> int:
+        """Get the maximum ID that we have allocated"""
+

 class PostgresSequenceGenerator(SequenceGenerator):
     """An implementation of SequenceGenerator which uses a postgres sequence"""

@@ -190,6 +194,17 @@ class PostgresSequenceGenerator(SequenceGenerator):
             % {"seq": self._sequence_name, "stream_name": stream_name}
         )

+    def get_max_allocated(self, txn: Cursor) -> int:
+        # We just read from the sequence what the last value we fetched was.
+        txn.execute(f"SELECT last_value, is_called FROM {self._sequence_name}")
+        row = txn.fetchone()
+        assert row is not None
+
+        last_value, is_called = row
+        if not is_called:
+            last_value -= 1
+        return last_value
+

 GetFirstCallbackType = Callable[[Cursor], int]

@@ -248,6 +263,15 @@ class LocalSequenceGenerator(SequenceGenerator):
         # There is nothing to do for in memory sequences
         pass

+    def get_max_allocated(self, txn: Cursor) -> int:
+        with self._lock:
+            if self._current_max_id is None:
+                assert self._callback is not None
+                self._current_max_id = self._callback(txn)
+                self._callback = None
+
+            return self._current_max_id
+

 def build_sequence_generator(
     db_conn: "LoggingDatabaseConnection",
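One subtlety in the Postgres implementation above: a sequence that has never had `nextval()` called on it reports `is_called = false`, meaning its `last_value` is the next value to be handed out rather than one already allocated, so the maximum allocated ID is one less. A tiny sketch mirroring that handling (an illustration, not part of the commit):

def max_allocated(last_value: int, is_called: bool) -> int:
    # Mirrors the last_value/is_called handling in get_max_allocated above:
    # when the sequence has never been called, last_value has not yet been
    # returned by nextval(), so it does not count as allocated.
    return last_value if is_called else last_value - 1


assert max_allocated(1, False) == 0  # fresh sequence: nothing allocated yet
assert max_allocated(7, True) == 7   # nextval() has already returned 7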
@@ -30,7 +30,12 @@ from synapse.handlers.room import RoomEventSource
 from synapse.handlers.typing import TypingNotificationEventSource
 from synapse.logging.opentracing import trace
 from synapse.streams import EventSource
-from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken
+from synapse.types import (
+    AbstractMultiWriterStreamToken,
+    MultiWriterStreamToken,
+    StreamKeyType,
+    StreamToken,
+)

 if TYPE_CHECKING:
     from synapse.server import HomeServer

@@ -91,6 +96,63 @@ class EventSources:
         )
         return token

+    async def bound_future_token(self, token: StreamToken) -> StreamToken:
+        """Bound a token that is ahead of the current token to the maximum
+        persisted values.
+
+        This ensures that if we wait for the given token we know the stream will
+        eventually advance to that point.
+
+        This works around a bug where older Synapse versions will give out
+        tokens for streams, and then after a restart will give back tokens where
+        the stream has "gone backwards".
+        """
+
+        current_token = self.get_current_token()
+
+        stream_key_to_id_gen = {
+            StreamKeyType.ROOM: self.store.get_events_stream_id_generator(),
+            StreamKeyType.PRESENCE: self.store.get_presence_stream_id_gen(),
+            StreamKeyType.RECEIPT: self.store.get_receipts_stream_id_gen(),
+            StreamKeyType.ACCOUNT_DATA: self.store.get_account_data_id_generator(),
+            StreamKeyType.PUSH_RULES: self.store.get_push_rules_stream_id_gen(),
+            StreamKeyType.TO_DEVICE: self.store.get_to_device_id_generator(),
+            StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(),
+            StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(),
+        }
+
+        for _, key in StreamKeyType.__members__.items():
+            if key == StreamKeyType.TYPING:
+                # Typing stream is allowed to "reset", and so comparisons don't
+                # really make sense as is.
+                # TODO: Figure out a better way of tracking resets.
+                continue
+
+            token_value = token.get_field(key)
+            current_value = current_token.get_field(key)
+
+            if isinstance(token_value, AbstractMultiWriterStreamToken):
+                assert type(current_value) is type(token_value)
+
+                if not token_value.is_before_or_eq(current_value):  # type: ignore[arg-type]
+                    max_token = await stream_key_to_id_gen[
+                        key
+                    ].get_max_allocated_token()
+
+                    token = token.copy_and_replace(
+                        key, token.room_key.bound_stream_token(max_token)
+                    )
+            else:
+                assert isinstance(current_value, int)
+                if current_value < token_value:
+                    max_token = await stream_key_to_id_gen[
+                        key
+                    ].get_max_allocated_token()
+
+                    token = token.copy_and_replace(key, min(token_value, max_token))
+
+        return token
+
     @trace
     async def get_start_token_for_pagination(self, room_id: str) -> StreamToken:
         """Get the start token for a given room to be used to paginate
@@ -536,6 +536,16 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):

         return True

+    def bound_stream_token(self, max_stream: int) -> "Self":
+        """Bound the stream positions to a maximum value"""
+
+        return type(self)(
+            stream=min(self.stream, max_stream),
+            instance_map=immutabledict(
+                {k: min(s, max_stream) for k, s in self.instance_map.items()}
+            ),
+        )
+

 @attr.s(frozen=True, slots=True, order=False)
 class RoomStreamToken(AbstractMultiWriterStreamToken):

@@ -722,6 +732,14 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
         else:
             return "s%d" % (self.stream,)

+    def bound_stream_token(self, max_stream: int) -> "RoomStreamToken":
+        """See super class"""
+
+        # This only makes sense for stream tokens.
+        assert self.topological is None
+
+        return super().bound_stream_token(max_stream)
+

 @attr.s(frozen=True, slots=True, order=False)
 class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
@@ -22,6 +22,7 @@ from unittest.mock import AsyncMock, Mock, patch

 from parameterized import parameterized

+from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactor

 from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules

@@ -35,7 +36,7 @@ from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVe
 from synapse.rest import admin
 from synapse.rest.client import knock, login, room
 from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, StreamKeyType, UserID, create_requester
 from synapse.util import Clock

 import tests.unittest

@@ -959,6 +960,76 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):

         self.fail("No push rules found")

+    def test_wait_for_future_sync_token(self) -> None:
+        """Test that if we receive a token that is ahead of our current token,
+        we'll wait until the stream position advances.
+
+        This can happen if replication streams start lagging, and the client's
+        previous sync request was serviced by a worker ahead of ours.
+        """
+        user = self.register_user("alice", "password")
+
+        # We simulate a lagging stream by getting a stream ID from the ID gen
+        # and then waiting to mark it as "persisted".
+        presence_id_gen = self.store.get_presence_stream_id_gen()
+        ctx_mgr = presence_id_gen.get_next()
+        stream_id = self.get_success(ctx_mgr.__aenter__())
+
+        # Create the new token based on the stream ID above.
+        current_token = self.hs.get_event_sources().get_current_token()
+        since_token = current_token.copy_and_advance(StreamKeyType.PRESENCE, stream_id)
+
+        sync_d = defer.ensureDeferred(
+            self.sync_handler.wait_for_sync_for_user(
+                create_requester(user),
+                generate_sync_config(user),
+                sync_version=SyncVersion.SYNC_V2,
+                request_key=generate_request_key(),
+                since_token=since_token,
+                timeout=0,
+            )
+        )
+
+        # This should block waiting for the presence stream to update
+        self.pump()
+        self.assertFalse(sync_d.called)
+
+        # Marking the stream ID as persisted should unblock the request.
+        self.get_success(ctx_mgr.__aexit__(None, None, None))
+
+        self.get_success(sync_d, by=1.0)
+
+    def test_wait_for_invalid_future_sync_token(self) -> None:
+        """Like the previous test, except we give a token that has a stream
+        position ahead of what is in the DB, i.e. it's invalid and we shouldn't
+        wait for the stream to advance (as it may never do so).
+
+        This can happen due to older versions of Synapse giving out stream
+        positions without persisting them in the DB, and so on restart the
+        stream would get reset back to an older position.
+        """
+        user = self.register_user("alice", "password")
+
+        # Create a token and arbitrarily advance one of the streams.
+        current_token = self.hs.get_event_sources().get_current_token()
+        since_token = current_token.copy_and_advance(
+            StreamKeyType.PRESENCE, current_token.presence_key + 1
+        )
+
+        sync_d = defer.ensureDeferred(
+            self.sync_handler.wait_for_sync_for_user(
+                create_requester(user),
+                generate_sync_config(user),
+                sync_version=SyncVersion.SYNC_V2,
+                request_key=generate_request_key(),
+                since_token=since_token,
+                timeout=0,
+            )
+        )
+
+        # We should return without waiting for the presence stream to advance.
+        self.get_success(sync_d)
+
+
 def generate_sync_config(
     user_id: str,
@@ -1386,10 +1386,12 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         # Create a future token that will cause us to wait. Since we never send a new
         # event to reach that future stream_ordering, the worker will wait until the
         # full timeout.
+        stream_id_gen = self.store.get_events_stream_id_generator()
+        stream_id = self.get_success(stream_id_gen.get_next().__aenter__())
         current_token = self.event_sources.get_current_token()
         future_position_token = current_token.copy_and_replace(
             StreamKeyType.ROOM,
-            RoomStreamToken(stream=current_token.room_key.stream + 1),
+            RoomStreamToken(stream=stream_id),
         )

         future_position_token_serialized = self.get_success(