mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 09:24:04 +01:00
Factor out MultiWriter
token from RoomStreamToken
(#16427)
This commit is contained in:
parent
ab9c1e8f39
commit
009b47badf
9 changed files with 115 additions and 61 deletions
1
changelog.d/16427.misc
Normal file
1
changelog.d/16427.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Factor out `MultiWriter` token from `RoomStreamToken`.
|
|
@ -171,8 +171,8 @@ class AdminHandler:
|
||||||
else:
|
else:
|
||||||
stream_ordering = room.stream_ordering
|
stream_ordering = room.stream_ordering
|
||||||
|
|
||||||
from_key = RoomStreamToken(0, 0)
|
from_key = RoomStreamToken(topological=0, stream=0)
|
||||||
to_key = RoomStreamToken(None, stream_ordering)
|
to_key = RoomStreamToken(stream=stream_ordering)
|
||||||
|
|
||||||
# Events that we've processed in this room
|
# Events that we've processed in this room
|
||||||
written_events: Set[str] = set()
|
written_events: Set[str] = set()
|
||||||
|
|
|
@ -192,8 +192,7 @@ class InitialSyncHandler:
|
||||||
)
|
)
|
||||||
elif event.membership == Membership.LEAVE:
|
elif event.membership == Membership.LEAVE:
|
||||||
room_end_token = RoomStreamToken(
|
room_end_token = RoomStreamToken(
|
||||||
None,
|
stream=event.stream_ordering,
|
||||||
event.stream_ordering,
|
|
||||||
)
|
)
|
||||||
deferred_room_state = run_in_background(
|
deferred_room_state = run_in_background(
|
||||||
self._state_storage_controller.get_state_for_events,
|
self._state_storage_controller.get_state_for_events,
|
||||||
|
|
|
@ -1708,7 +1708,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
|
||||||
|
|
||||||
if from_key.topological:
|
if from_key.topological:
|
||||||
logger.warning("Stream has topological part!!!! %r", from_key)
|
logger.warning("Stream has topological part!!!! %r", from_key)
|
||||||
from_key = RoomStreamToken(None, from_key.stream)
|
from_key = RoomStreamToken(stream=from_key.stream)
|
||||||
|
|
||||||
app_service = self.store.get_app_service_by_user_id(user.to_string())
|
app_service = self.store.get_app_service_by_user_id(user.to_string())
|
||||||
if app_service:
|
if app_service:
|
||||||
|
|
|
@ -2333,7 +2333,7 @@ class SyncHandler:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
leave_token = now_token.copy_and_replace(
|
leave_token = now_token.copy_and_replace(
|
||||||
StreamKeyType.ROOM, RoomStreamToken(None, event.stream_ordering)
|
StreamKeyType.ROOM, RoomStreamToken(stream=event.stream_ordering)
|
||||||
)
|
)
|
||||||
room_entries.append(
|
room_entries.append(
|
||||||
RoomSyncResultBuilder(
|
RoomSyncResultBuilder(
|
||||||
|
|
|
@ -146,7 +146,7 @@ class PurgeHistoryRestServlet(RestServlet):
|
||||||
# RoomStreamToken expects [int] not Optional[int]
|
# RoomStreamToken expects [int] not Optional[int]
|
||||||
assert event.internal_metadata.stream_ordering is not None
|
assert event.internal_metadata.stream_ordering is not None
|
||||||
room_token = RoomStreamToken(
|
room_token = RoomStreamToken(
|
||||||
event.depth, event.internal_metadata.stream_ordering
|
topological=event.depth, stream=event.internal_metadata.stream_ordering
|
||||||
)
|
)
|
||||||
token = await room_token.to_string(self.store)
|
token = await room_token.to_string(self.store)
|
||||||
|
|
||||||
|
|
|
@ -266,7 +266,7 @@ def generate_next_token(
|
||||||
# when we are going backwards so we subtract one from the
|
# when we are going backwards so we subtract one from the
|
||||||
# stream part.
|
# stream part.
|
||||||
last_stream_ordering -= 1
|
last_stream_ordering -= 1
|
||||||
return RoomStreamToken(last_topo_ordering, last_stream_ordering)
|
return RoomStreamToken(topological=last_topo_ordering, stream=last_stream_ordering)
|
||||||
|
|
||||||
|
|
||||||
def _make_generic_sql_bound(
|
def _make_generic_sql_bound(
|
||||||
|
@ -558,7 +558,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
if p > min_pos
|
if p > min_pos
|
||||||
}
|
}
|
||||||
|
|
||||||
return RoomStreamToken(None, min_pos, immutabledict(positions))
|
return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))
|
||||||
|
|
||||||
async def get_room_events_stream_for_rooms(
|
async def get_room_events_stream_for_rooms(
|
||||||
self,
|
self,
|
||||||
|
@ -708,7 +708,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
ret.reverse()
|
ret.reverse()
|
||||||
|
|
||||||
if rows:
|
if rows:
|
||||||
key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
|
key = RoomStreamToken(stream=min(r.stream_ordering for r in rows))
|
||||||
else:
|
else:
|
||||||
# Assume we didn't get anything because there was nothing to
|
# Assume we didn't get anything because there was nothing to
|
||||||
# get.
|
# get.
|
||||||
|
@ -969,7 +969,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
topo = await self.db_pool.runInteraction(
|
topo = await self.db_pool.runInteraction(
|
||||||
"_get_max_topological_txn", self._get_max_topological_txn, room_id
|
"_get_max_topological_txn", self._get_max_topological_txn, room_id
|
||||||
)
|
)
|
||||||
return RoomStreamToken(topo, stream_ordering)
|
return RoomStreamToken(topological=topo, stream=stream_ordering)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def get_stream_id_for_event_txn(
|
def get_stream_id_for_event_txn(
|
||||||
|
@ -1033,7 +1033,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
retcols=("stream_ordering", "topological_ordering"),
|
retcols=("stream_ordering", "topological_ordering"),
|
||||||
desc="get_topological_token_for_event",
|
desc="get_topological_token_for_event",
|
||||||
)
|
)
|
||||||
return RoomStreamToken(row["topological_ordering"], row["stream_ordering"])
|
return RoomStreamToken(
|
||||||
|
topological=row["topological_ordering"], stream=row["stream_ordering"]
|
||||||
|
)
|
||||||
|
|
||||||
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
|
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
|
||||||
"""Gets the topological token in a room after or at the given stream
|
"""Gets the topological token in a room after or at the given stream
|
||||||
|
@ -1114,8 +1116,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
else:
|
else:
|
||||||
topo = None
|
topo = None
|
||||||
internal = event.internal_metadata
|
internal = event.internal_metadata
|
||||||
internal.before = RoomStreamToken(topo, stream - 1)
|
internal.before = RoomStreamToken(topological=topo, stream=stream - 1)
|
||||||
internal.after = RoomStreamToken(topo, stream)
|
internal.after = RoomStreamToken(topological=topo, stream=stream)
|
||||||
internal.order = (int(topo) if topo else 0, int(stream))
|
internal.order = (int(topo) if topo else 0, int(stream))
|
||||||
|
|
||||||
async def get_events_around(
|
async def get_events_around(
|
||||||
|
@ -1191,11 +1193,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
# Paginating backwards includes the event at the token, but paginating
|
# Paginating backwards includes the event at the token, but paginating
|
||||||
# forward doesn't.
|
# forward doesn't.
|
||||||
before_token = RoomStreamToken(
|
before_token = RoomStreamToken(
|
||||||
results["topological_ordering"] - 1, results["stream_ordering"]
|
topological=results["topological_ordering"] - 1,
|
||||||
|
stream=results["stream_ordering"],
|
||||||
)
|
)
|
||||||
|
|
||||||
after_token = RoomStreamToken(
|
after_token = RoomStreamToken(
|
||||||
results["topological_ordering"], results["stream_ordering"]
|
topological=results["topological_ordering"],
|
||||||
|
stream=results["stream_ordering"],
|
||||||
)
|
)
|
||||||
|
|
||||||
rows, start_token = self._paginate_room_events_txn(
|
rows, start_token = self._paginate_room_events_txn(
|
||||||
|
|
|
@ -61,6 +61,8 @@ from synapse.util.cancellation import cancellable
|
||||||
from synapse.util.stringutils import parse_and_validate_server_name
|
from synapse.util.stringutils import parse_and_validate_server_name
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from synapse.appservice.api import ApplicationService
|
from synapse.appservice.api import ApplicationService
|
||||||
from synapse.storage.databases.main import DataStore, PurgeEventsStore
|
from synapse.storage.databases.main import DataStore, PurgeEventsStore
|
||||||
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
|
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
|
||||||
|
@ -437,7 +439,78 @@ def map_username_to_mxid_localpart(
|
||||||
|
|
||||||
|
|
||||||
@attr.s(frozen=True, slots=True, order=False)
|
@attr.s(frozen=True, slots=True, order=False)
|
||||||
class RoomStreamToken:
|
class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
|
||||||
|
"""An abstract stream token class for streams that supports multiple
|
||||||
|
writers.
|
||||||
|
|
||||||
|
This works by keeping track of the stream position of each writer,
|
||||||
|
represented by a default `stream` attribute and a map of instance name to
|
||||||
|
stream position of any writers that are ahead of the default stream
|
||||||
|
position.
|
||||||
|
"""
|
||||||
|
|
||||||
|
stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
|
||||||
|
|
||||||
|
instance_map: "immutabledict[str, int]" = attr.ib(
|
||||||
|
factory=immutabledict,
|
||||||
|
validator=attr.validators.deep_mapping(
|
||||||
|
key_validator=attr.validators.instance_of(str),
|
||||||
|
value_validator=attr.validators.instance_of(int),
|
||||||
|
mapping_validator=attr.validators.instance_of(immutabledict),
|
||||||
|
),
|
||||||
|
kw_only=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def parse(cls, store: "DataStore", string: str) -> "Self":
|
||||||
|
"""Parse the string representation of the token."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def to_string(self, store: "DataStore") -> str:
|
||||||
|
"""Serialize the token into its string representation."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def copy_and_advance(self, other: "Self") -> "Self":
|
||||||
|
"""Return a new token such that if an event is after both this token and
|
||||||
|
the other token, then its after the returned token too.
|
||||||
|
"""
|
||||||
|
|
||||||
|
max_stream = max(self.stream, other.stream)
|
||||||
|
|
||||||
|
instance_map = {
|
||||||
|
instance: max(
|
||||||
|
self.instance_map.get(instance, self.stream),
|
||||||
|
other.instance_map.get(instance, other.stream),
|
||||||
|
)
|
||||||
|
for instance in set(self.instance_map).union(other.instance_map)
|
||||||
|
}
|
||||||
|
|
||||||
|
return attr.evolve(
|
||||||
|
self, stream=max_stream, instance_map=immutabledict(instance_map)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_max_stream_pos(self) -> int:
|
||||||
|
"""Get the maximum stream position referenced in this token.
|
||||||
|
|
||||||
|
The corresponding "min" position is, by definition just `self.stream`.
|
||||||
|
|
||||||
|
This is used to handle tokens that have non-empty `instance_map`, and so
|
||||||
|
reference stream positions after the `self.stream` position.
|
||||||
|
"""
|
||||||
|
return max(self.instance_map.values(), default=self.stream)
|
||||||
|
|
||||||
|
def get_stream_pos_for_instance(self, instance_name: str) -> int:
|
||||||
|
"""Get the stream position that the given writer was at at this token."""
|
||||||
|
|
||||||
|
# If we don't have an entry for the instance we can assume that it was
|
||||||
|
# at `self.stream`.
|
||||||
|
return self.instance_map.get(instance_name, self.stream)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(frozen=True, slots=True, order=False)
|
||||||
|
class RoomStreamToken(AbstractMultiWriterStreamToken):
|
||||||
"""Tokens are positions between events. The token "s1" comes after event 1.
|
"""Tokens are positions between events. The token "s1" comes after event 1.
|
||||||
|
|
||||||
s0 s1
|
s0 s1
|
||||||
|
@ -514,16 +587,8 @@ class RoomStreamToken:
|
||||||
|
|
||||||
topological: Optional[int] = attr.ib(
|
topological: Optional[int] = attr.ib(
|
||||||
validator=attr.validators.optional(attr.validators.instance_of(int)),
|
validator=attr.validators.optional(attr.validators.instance_of(int)),
|
||||||
)
|
kw_only=True,
|
||||||
stream: int = attr.ib(validator=attr.validators.instance_of(int))
|
default=None,
|
||||||
|
|
||||||
instance_map: "immutabledict[str, int]" = attr.ib(
|
|
||||||
factory=immutabledict,
|
|
||||||
validator=attr.validators.deep_mapping(
|
|
||||||
key_validator=attr.validators.instance_of(str),
|
|
||||||
value_validator=attr.validators.instance_of(int),
|
|
||||||
mapping_validator=attr.validators.instance_of(immutabledict),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __attrs_post_init__(self) -> None:
|
def __attrs_post_init__(self) -> None:
|
||||||
|
@ -583,17 +648,7 @@ class RoomStreamToken:
|
||||||
if self.topological or other.topological:
|
if self.topological or other.topological:
|
||||||
raise Exception("Can't advance topological tokens")
|
raise Exception("Can't advance topological tokens")
|
||||||
|
|
||||||
max_stream = max(self.stream, other.stream)
|
return super().copy_and_advance(other)
|
||||||
|
|
||||||
instance_map = {
|
|
||||||
instance: max(
|
|
||||||
self.instance_map.get(instance, self.stream),
|
|
||||||
other.instance_map.get(instance, other.stream),
|
|
||||||
)
|
|
||||||
for instance in set(self.instance_map).union(other.instance_map)
|
|
||||||
}
|
|
||||||
|
|
||||||
return RoomStreamToken(None, max_stream, immutabledict(instance_map))
|
|
||||||
|
|
||||||
def as_historical_tuple(self) -> Tuple[int, int]:
|
def as_historical_tuple(self) -> Tuple[int, int]:
|
||||||
"""Returns a tuple of `(topological, stream)` for historical tokens.
|
"""Returns a tuple of `(topological, stream)` for historical tokens.
|
||||||
|
@ -619,16 +674,6 @@ class RoomStreamToken:
|
||||||
# at `self.stream`.
|
# at `self.stream`.
|
||||||
return self.instance_map.get(instance_name, self.stream)
|
return self.instance_map.get(instance_name, self.stream)
|
||||||
|
|
||||||
def get_max_stream_pos(self) -> int:
|
|
||||||
"""Get the maximum stream position referenced in this token.
|
|
||||||
|
|
||||||
The corresponding "min" position is, by definition just `self.stream`.
|
|
||||||
|
|
||||||
This is used to handle tokens that have non-empty `instance_map`, and so
|
|
||||||
reference stream positions after the `self.stream` position.
|
|
||||||
"""
|
|
||||||
return max(self.instance_map.values(), default=self.stream)
|
|
||||||
|
|
||||||
async def to_string(self, store: "DataStore") -> str:
|
async def to_string(self, store: "DataStore") -> str:
|
||||||
if self.topological is not None:
|
if self.topological is not None:
|
||||||
return "t%d-%d" % (self.topological, self.stream)
|
return "t%d-%d" % (self.topological, self.stream)
|
||||||
|
@ -838,23 +883,28 @@ class StreamToken:
|
||||||
return getattr(self, key.value)
|
return getattr(self, key.value)
|
||||||
|
|
||||||
|
|
||||||
StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
|
StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
class PersistedEventPosition:
|
class PersistedPosition:
|
||||||
|
"""Position of a newly persisted row with instance that persisted it."""
|
||||||
|
|
||||||
|
instance_name: str
|
||||||
|
stream: int
|
||||||
|
|
||||||
|
def persisted_after(self, token: AbstractMultiWriterStreamToken) -> bool:
|
||||||
|
return token.get_stream_pos_for_instance(self.instance_name) < self.stream
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class PersistedEventPosition(PersistedPosition):
|
||||||
"""Position of a newly persisted event with instance that persisted it.
|
"""Position of a newly persisted event with instance that persisted it.
|
||||||
|
|
||||||
This can be used to test whether the event is persisted before or after a
|
This can be used to test whether the event is persisted before or after a
|
||||||
RoomStreamToken.
|
RoomStreamToken.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
instance_name: str
|
|
||||||
stream: int
|
|
||||||
|
|
||||||
def persisted_after(self, token: RoomStreamToken) -> bool:
|
|
||||||
return token.get_stream_pos_for_instance(self.instance_name) < self.stream
|
|
||||||
|
|
||||||
def to_room_stream_token(self) -> RoomStreamToken:
|
def to_room_stream_token(self) -> RoomStreamToken:
|
||||||
"""Converts the position to a room stream token such that events
|
"""Converts the position to a room stream token such that events
|
||||||
persisted in the same room after this position will be after the
|
persisted in the same room after this position will be after the
|
||||||
|
@ -865,7 +915,7 @@ class PersistedEventPosition:
|
||||||
"""
|
"""
|
||||||
# Doing the naive thing satisfies the desired properties described in
|
# Doing the naive thing satisfies the desired properties described in
|
||||||
# the docstring.
|
# the docstring.
|
||||||
return RoomStreamToken(None, self.stream)
|
return RoomStreamToken(stream=self.stream)
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
|
|
@ -86,7 +86,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
[event],
|
[event],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.handler.notify_interested_services(RoomStreamToken(None, 1))
|
self.handler.notify_interested_services(RoomStreamToken(stream=1))
|
||||||
|
|
||||||
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
|
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
|
||||||
interested_service, events=[event]
|
interested_service, events=[event]
|
||||||
|
@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
|
self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
|
||||||
self.handler.notify_interested_services(RoomStreamToken(None, 0))
|
self.handler.notify_interested_services(RoomStreamToken(stream=0))
|
||||||
|
|
||||||
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
|
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.handler.notify_interested_services(RoomStreamToken(None, 0))
|
self.handler.notify_interested_services(RoomStreamToken(stream=0))
|
||||||
|
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
self.mock_as_api.query_user.called,
|
self.mock_as_api.query_user.called,
|
||||||
|
@ -441,7 +441,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_application_service_handler()._notify_interested_services(
|
self.hs.get_application_service_handler()._notify_interested_services(
|
||||||
RoomStreamToken(
|
RoomStreamToken(
|
||||||
None, self.hs.get_application_service_handler().current_max
|
stream=self.hs.get_application_service_handler().current_max
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue