0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-12 04:52:26 +01:00

Fix bug where sync could get stuck when using workers (#17438)

This is because we serialized the token wrong if the instance map
contained entries from before the minimum token.
This commit is contained in:
Erik Johnston 2024-07-15 16:13:04 +01:00 committed by GitHub
parent d88ba45db9
commit df11af14db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 138 additions and 10 deletions

1
changelog.d/17438.bugfix Normal file
View file

@ -0,0 +1 @@
Fix rare bug where `/sync` would break for a user when using workers with multiple stream writers.

View file

@ -699,10 +699,17 @@ class SlidingSyncHandler:
instance_to_max_stream_ordering_map[instance_name] = stream_ordering
# Then assemble the `RoomStreamToken`
min_stream_pos = min(instance_to_max_stream_ordering_map.values())
membership_snapshot_token = RoomStreamToken(
# Minimum position in the `instance_map`
stream=min(instance_to_max_stream_ordering_map.values()),
instance_map=immutabledict(instance_to_max_stream_ordering_map),
stream=min_stream_pos,
instance_map=immutabledict(
{
instance_name: stream_pos
for instance_name, stream_pos in instance_to_max_stream_ordering_map.items()
if stream_pos > min_stream_pos
}
),
)
# Since we fetched the users room list at some point in time after the from/to

View file

@ -20,6 +20,7 @@
#
#
import abc
import logging
import re
import string
from enum import Enum
@ -74,6 +75,9 @@ if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore, PurgeEventsStore
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
logger = logging.getLogger(__name__)
# Define a state map type from type/state_key to T (usually an event ID or
# event)
T = TypeVar("T")
@ -454,6 +458,8 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
represented by a default `stream` attribute and a map of instance name to
stream position of any writers that are ahead of the default stream
position.
The values in `instance_map` must be greater than the `stream` attribute.
"""
stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
@ -468,6 +474,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
kw_only=True,
)
def __attrs_post_init__(self) -> None:
# Enforce that all instances have a value greater than the min stream
# position.
for i, v in self.instance_map.items():
if v <= self.stream:
raise ValueError(
f"'instance_map' includes a stream position before the main 'stream' attribute. Instance: {i}"
)
@classmethod
@abc.abstractmethod
async def parse(cls, store: "DataStore", string: str) -> "Self":
@ -494,6 +509,9 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
for instance in set(self.instance_map).union(other.instance_map)
}
# Filter out any redundant entries.
instance_map = {i: s for i, s in instance_map.items() if s > max_stream}
return attr.evolve(
self, stream=max_stream, instance_map=immutabledict(instance_map)
)
@ -539,10 +557,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
def bound_stream_token(self, max_stream: int) -> "Self":
"""Bound the stream positions to a maximum value"""
min_pos = min(self.stream, max_stream)
return type(self)(
stream=min(self.stream, max_stream),
stream=min_pos,
instance_map=immutabledict(
{k: min(s, max_stream) for k, s in self.instance_map.items()}
{
k: min(s, max_stream)
for k, s in self.instance_map.items()
if min(s, max_stream) > min_pos
}
),
)
@ -637,6 +660,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
"Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
)
super().__attrs_post_init__()
@classmethod
async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
try:
@ -651,6 +676,11 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
instance_map = {}
for part in parts[1:]:
if not part:
# Handle tokens of the form `m5~`, which were created by
# a bug
continue
key, value = part.split(".")
instance_id = int(key)
pos = int(value)
@ -666,7 +696,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
except CancelledError:
raise
except Exception:
pass
# We log an exception here as even though this *might* be a client
# handing a bad token, its more likely that Synapse returned a bad
# token (and we really want to catch those!).
logger.exception("Failed to parse stream token: %r", string)
raise SynapseError(400, "Invalid room stream token %r" % (string,))
@classmethod
@ -713,6 +746,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
return self.instance_map.get(instance_name, self.stream)
async def to_string(self, store: "DataStore") -> str:
"""See class level docstring for information about the format."""
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
elif self.instance_map:
@ -727,8 +762,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
instance_id = await store.get_id_for_instance(name)
entries.append(f"{instance_id}.{pos}")
encoded_map = "~".join(entries)
return f"m{self.stream}~{encoded_map}"
if entries:
encoded_map = "~".join(entries)
return f"m{self.stream}~{encoded_map}"
return f"s{self.stream}"
else:
return "s%d" % (self.stream,)
@ -756,6 +793,11 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
instance_map = {}
for part in parts[1:]:
if not part:
# Handle tokens of the form `m5~`, which were created by
# a bug
continue
key, value = part.split(".")
instance_id = int(key)
pos = int(value)
@ -770,10 +812,15 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
except CancelledError:
raise
except Exception:
pass
# We log an exception here as even though this *might* be a client
# handing a bad token, its more likely that Synapse returned a bad
# token (and we really want to catch those!).
logger.exception("Failed to parse stream token: %r", string)
raise SynapseError(400, "Invalid stream token %r" % (string,))
async def to_string(self, store: "DataStore") -> str:
"""See class level docstring for information about the format."""
if self.instance_map:
entries = []
for name, pos in self.instance_map.items():
@ -786,8 +833,10 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
instance_id = await store.get_id_for_instance(name)
entries.append(f"{instance_id}.{pos}")
encoded_map = "~".join(entries)
return f"m{self.stream}~{encoded_map}"
if entries:
encoded_map = "~".join(entries)
return f"m{self.stream}~{encoded_map}"
return str(self.stream)
else:
return str(self.stream)

View file

@ -19,9 +19,18 @@
#
#
from typing import Type
from unittest import skipUnless
from immutabledict import immutabledict
from parameterized import parameterized_class
from synapse.api.errors import SynapseError
from synapse.types import (
AbstractMultiWriterStreamToken,
MultiWriterStreamToken,
RoomAlias,
RoomStreamToken,
UserID,
get_domain_from_id,
get_localpart_from_id,
@ -29,6 +38,7 @@ from synapse.types import (
)
from tests import unittest
from tests.utils import USE_POSTGRES_FOR_TESTS
class IsMineIDTests(unittest.HomeserverTestCase):
@ -127,3 +137,64 @@ class MapUsernameTestCase(unittest.TestCase):
# this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
@parameterized_class(
("token_type",),
[
(MultiWriterStreamToken,),
(RoomStreamToken,),
],
class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}",
)
class MultiWriterTokenTestCase(unittest.HomeserverTestCase):
"""Tests for the different types of multi writer tokens."""
token_type: Type[AbstractMultiWriterStreamToken]
def test_basic_token(self) -> None:
"""Test that a simple stream token can be serialized and unserialized"""
store = self.hs.get_datastores().main
token = self.token_type(stream=5)
string_token = self.get_success(token.to_string(store))
if isinstance(token, RoomStreamToken):
self.assertEqual(string_token, "s5")
else:
self.assertEqual(string_token, "5")
parsed_token = self.get_success(self.token_type.parse(store, string_token))
self.assertEqual(parsed_token, token)
@skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres")
def test_instance_map(self) -> None:
"""Test for stream token with instance map"""
store = self.hs.get_datastores().main
token = self.token_type(stream=5, instance_map=immutabledict({"foo": 6}))
string_token = self.get_success(token.to_string(store))
self.assertEqual(string_token, "m5~1.6")
parsed_token = self.get_success(self.token_type.parse(store, string_token))
self.assertEqual(parsed_token, token)
def test_instance_map_assertion(self) -> None:
"""Test that we assert values in the instance map are greater than the
min stream position"""
with self.assertRaises(ValueError):
self.token_type(stream=5, instance_map=immutabledict({"foo": 4}))
with self.assertRaises(ValueError):
self.token_type(stream=5, instance_map=immutabledict({"foo": 5}))
def test_parse_bad_token(self) -> None:
"""Test that we can parse tokens produced by a bug in Synapse of the
form `m5~`"""
store = self.hs.get_datastores().main
parsed_token = self.get_success(self.token_type.parse(store, "m5~"))
self.assertEqual(parsed_token, self.token_type(stream=5))