0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-12 04:52:26 +01:00

Sliding sync: Store the per-connection state in the database. (#17599)

Based on #17600

---------

Co-authored-by: Eric Eastwood <eric.eastwood@beta.gouv.fr>
This commit is contained in:
Erik Johnston 2024-08-29 16:26:58 +01:00 committed by Erik Johnston
parent dab88a7b1f
commit b913aaa788
14 changed files with 694 additions and 116 deletions

1
changelog.d/17599.misc Normal file
View file

@ -0,0 +1 @@
Store sliding sync per-connection state in the database.

View file

@ -98,6 +98,7 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.databases.main.search import SearchStore from synapse.storage.databases.main.search import SearchStore
from synapse.storage.databases.main.session import SessionStore from synapse.storage.databases.main.session import SessionStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore
@ -159,6 +160,7 @@ class GenericWorkerStore(
SessionStore, SessionStore,
TaskSchedulerWorkerStore, TaskSchedulerWorkerStore,
ExperimentalFeaturesStore, ExperimentalFeaturesStore,
SlidingSyncStore,
): ):
# Properties that multiple storage classes define. Tell mypy what the # Properties that multiple storage classes define. Tell mypy what the
# expected type is. # expected type is.

View file

@ -100,7 +100,7 @@ class SlidingSyncHandler:
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.connection_store = SlidingSyncConnectionStore() self.connection_store = SlidingSyncConnectionStore(self.store)
self.extensions = SlidingSyncExtensionHandler(hs) self.extensions = SlidingSyncExtensionHandler(hs)
self.room_lists = SlidingSyncRoomLists(hs) self.room_lists = SlidingSyncRoomLists(hs)
@ -221,16 +221,11 @@ class SlidingSyncHandler:
# amount of time (more with round-trips and re-processing) in the end to # amount of time (more with round-trips and re-processing) in the end to
# get everything again. # get everything again.
previous_connection_state = ( previous_connection_state = (
await self.connection_store.get_per_connection_state( await self.connection_store.get_and_clear_connection_positions(
sync_config, from_token sync_config, from_token
) )
) )
await self.connection_store.mark_token_seen(
sync_config=sync_config,
from_token=from_token,
)
# Get all of the room IDs that the user should be able to see in the sync # Get all of the room IDs that the user should be able to see in the sync
# response # response
has_lists = sync_config.lists is not None and len(sync_config.lists) > 0 has_lists = sync_config.lists is not None and len(sync_config.lists) > 0

View file

@ -13,12 +13,12 @@
# #
import logging import logging
from typing import TYPE_CHECKING, Dict, Optional, Tuple from typing import TYPE_CHECKING, Optional
import attr import attr
from synapse.api.errors import SlidingSyncUnknownPosition
from synapse.logging.opentracing import trace from synapse.logging.opentracing import trace
from synapse.storage.databases.main import DataStore
from synapse.types import SlidingSyncStreamToken from synapse.types import SlidingSyncStreamToken
from synapse.types.handlers.sliding_sync import ( from synapse.types.handlers.sliding_sync import (
MutablePerConnectionState, MutablePerConnectionState,
@ -61,22 +61,9 @@ class SlidingSyncConnectionStore:
to mapping of room ID to `HaveSentRoom`. to mapping of room ID to `HaveSentRoom`.
""" """
# `(user_id, conn_id)` -> `connection_position` -> `PerConnectionState` store: "DataStore"
_connections: Dict[Tuple[str, str], Dict[int, PerConnectionState]] = attr.Factory(
dict
)
async def is_valid_token( async def get_and_clear_connection_positions(
self, sync_config: SlidingSyncConfig, connection_token: int
) -> bool:
"""Return whether the connection token is valid/recognized"""
if connection_token == 0:
return True
conn_key = self._get_connection_key(sync_config)
return connection_token in self._connections.get(conn_key, {})
async def get_per_connection_state(
self, self,
sync_config: SlidingSyncConfig, sync_config: SlidingSyncConfig,
from_token: Optional[SlidingSyncStreamToken], from_token: Optional[SlidingSyncStreamToken],
@ -86,23 +73,21 @@ class SlidingSyncConnectionStore:
Raises: Raises:
SlidingSyncUnknownPosition if the connection_token is unknown SlidingSyncUnknownPosition if the connection_token is unknown
""" """
if from_token is None: # If this is our first request, there is no previous connection state to fetch out of the database
if from_token is None or from_token.connection_position == 0:
return PerConnectionState() return PerConnectionState()
connection_position = from_token.connection_position conn_id = sync_config.conn_id or ""
if connection_position == 0:
# Initial sync (request without a `from_token`) starts at `0` so
# there is no existing per-connection state
return PerConnectionState()
conn_key = self._get_connection_key(sync_config) device_id = sync_config.requester.device_id
sync_statuses = self._connections.get(conn_key, {}) assert device_id is not None
connection_state = sync_statuses.get(connection_position)
if connection_state is None: return await self.store.get_and_clear_connection_positions(
raise SlidingSyncUnknownPosition() sync_config.user.to_string(),
device_id,
return connection_state conn_id,
from_token.connection_position,
)
@trace @trace
async def record_new_state( async def record_new_state(
@ -116,85 +101,28 @@ class SlidingSyncConnectionStore:
If there are no changes to the state this may return the same token as If there are no changes to the state this may return the same token as
the existing per-connection state. the existing per-connection state.
""" """
prev_connection_token = 0
if from_token is not None:
prev_connection_token = from_token.connection_position
if not new_connection_state.has_updates(): if not new_connection_state.has_updates():
return prev_connection_token if from_token is not None:
return from_token.connection_position
else:
return 0
conn_key = self._get_connection_key(sync_config) # A from token with a zero connection position means there was no
sync_statuses = self._connections.setdefault(conn_key, {}) # previously stored connection state, so we treat a zero the same as
# there being no previous position.
previous_connection_position = None
if from_token is not None and from_token.connection_position != 0:
previous_connection_position = from_token.connection_position
# Generate a new token, removing any existing entries in that token
# (which can happen if requests get resent).
new_store_token = prev_connection_token + 1
sync_statuses.pop(new_store_token, None)
# We copy the `MutablePerConnectionState` so that the inner `ChainMap`s
# don't grow forever.
sync_statuses[new_store_token] = new_connection_state.copy()
return new_store_token
@trace
async def mark_token_seen(
self,
sync_config: SlidingSyncConfig,
from_token: Optional[SlidingSyncStreamToken],
) -> None:
"""We have received a request with the given token, so we can clear out
any other tokens associated with the connection.
If there is no from token then we have started afresh, and so we delete
all tokens associated with the device.
"""
# Clear out any tokens for the connection that doesn't match the one
# from the request.
conn_key = self._get_connection_key(sync_config)
sync_statuses = self._connections.pop(conn_key, {})
if from_token is None:
return
sync_statuses = {
connection_token: room_statuses
for connection_token, room_statuses in sync_statuses.items()
if connection_token == from_token.connection_position
}
if sync_statuses:
self._connections[conn_key] = sync_statuses
@staticmethod
def _get_connection_key(sync_config: SlidingSyncConfig) -> Tuple[str, str]:
"""Return a unique identifier for this connection.
The first part is simply the user ID.
The second part is generally a combination of device ID and conn_id.
However, both these two are optional (e.g. puppet access tokens don't
have device IDs), so this handles those edge cases.
We use this over the raw `conn_id` to avoid clashes between different
clients that use the same `conn_id`. Imagine a user uses a web client
that uses `conn_id: main_sync_loop` and an Android client that also has
a `conn_id: main_sync_loop`.
"""
user_id = sync_config.user.to_string()
# Only one sliding sync connection is allowed per given conn_id (empty
# or not).
conn_id = sync_config.conn_id or "" conn_id = sync_config.conn_id or ""
if sync_config.requester.device_id: device_id = sync_config.requester.device_id
return (user_id, f"D/{sync_config.requester.device_id}/{conn_id}") assert device_id is not None
if sync_config.requester.access_token_id: return await self.store.persist_per_connection_state(
# If we don't have a device, then the access token ID should be a sync_config.user.to_string(),
# stable ID. device_id,
return (user_id, f"A/{sync_config.requester.access_token_id}/{conn_id}") conn_id,
previous_connection_position,
# If we have neither then its likely an AS or some weird token. Either new_connection_state,
# way we can just fail here. )
raise Exception("Cannot use sliding sync with access token type")

View file

@ -64,6 +64,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor, SQLQueryParameters from synapse.storage.types import Connection, Cursor, SQLQueryParameters
from synapse.types import StrCollection
from synapse.util.async_helpers import delay_cancellation from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
@ -1095,6 +1096,48 @@ class DatabasePool:
txn.execute(sql, vals) txn.execute(sql, vals)
@staticmethod
def simple_insert_returning_txn(
txn: LoggingTransaction,
table: str,
values: Dict[str, Any],
returning: StrCollection,
) -> Tuple[Any, ...]:
"""Executes a `INSERT INTO... RETURNING...` statement (or equivalent for
SQLite versions that don't support it).
"""
if txn.database_engine.supports_returning:
sql = "INSERT INTO %s (%s) VALUES(%s) RETURNING %s" % (
table,
", ".join(k for k in values.keys()),
", ".join("?" for _ in values.keys()),
", ".join(k for k in returning),
)
txn.execute(sql, list(values.values()))
row = txn.fetchone()
assert row is not None
return row
else:
# For old versions of SQLite we do a standard insert and then can
# use `last_insert_rowid` to get at the row we just inserted
DatabasePool.simple_insert_txn(
txn,
table=table,
values=values,
)
txn.execute("SELECT last_insert_rowid()")
row = txn.fetchone()
assert row is not None
(rowid,) = row
row = DatabasePool.simple_select_one_txn(
txn, table=table, keyvalues={"rowid": rowid}, retcols=returning
)
assert row is not None
return row
async def simple_insert_many( async def simple_insert_many(
self, self,
table: str, table: str,

View file

@ -33,6 +33,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection, LoggingDatabaseConnection,
LoggingTransaction, LoggingTransaction,
) )
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
@ -156,6 +157,7 @@ class DataStore(
LockStore, LockStore,
SessionStore, SessionStore,
TaskSchedulerWorkerStore, TaskSchedulerWorkerStore,
SlidingSyncStore,
): ):
def __init__( def __init__(
self, self,

View file

@ -0,0 +1,491 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, cast
import attr
from synapse.api.errors import SlidingSyncUnknownPosition
from synapse.logging.opentracing import log_kv
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import MultiWriterStreamToken, RoomStreamToken
from synapse.types.handlers.sliding_sync import (
HaveSentRoom,
HaveSentRoomFlag,
MutablePerConnectionState,
PerConnectionState,
RoomStatusMap,
RoomSyncConfig,
)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
class SlidingSyncStore(SQLBaseStore):
async def persist_per_connection_state(
self,
user_id: str,
device_id: str,
conn_id: str,
previous_connection_position: Optional[int],
per_connection_state: "MutablePerConnectionState",
) -> int:
"""Persist updates to the per-connection state for a sliding sync
connection.
Returns:
The connection position of the newly persisted state.
"""
# This cast is safe because the downstream code only cares about
# `store.get_id_for_instance(...)` and `StreamWorkerStore` is mixed
# alongside `SlidingSyncStore` wherever we create a store.
store = cast("DataStore", self)
return await self.db_pool.runInteraction(
"persist_per_connection_state",
self.persist_per_connection_state_txn,
user_id=user_id,
device_id=device_id,
conn_id=conn_id,
previous_connection_position=previous_connection_position,
per_connection_state=await PerConnectionStateDB.from_state(
per_connection_state, store
),
)
def persist_per_connection_state_txn(
self,
txn: LoggingTransaction,
user_id: str,
device_id: str,
conn_id: str,
previous_connection_position: Optional[int],
per_connection_state: "PerConnectionStateDB",
) -> int:
# First we fetch (or create) the connection key associated with the
# previous connection position.
if previous_connection_position is not None:
# The `previous_connection_position` is a user-supplied value, so we
# need to make sure that the one they supplied is actually theirs.
sql = """
SELECT connection_key
FROM sliding_sync_connection_positions
INNER JOIN sliding_sync_connections USING (connection_key)
WHERE
connection_position = ?
AND user_id = ? AND effective_device_id = ? AND conn_id = ?
"""
txn.execute(
sql, (previous_connection_position, user_id, device_id, conn_id)
)
row = txn.fetchone()
if row is None:
raise SlidingSyncUnknownPosition()
(connection_key,) = row
else:
# We're restarting the connection, so we clear the previous existing data we
# used to track it. We do this here to ensure that if we get lots of
# one-shot requests we don't stack up lots of entries. We have `ON DELETE
# CASCADE` setup on the dependent tables so this will clear out all the
# associated data.
self.db_pool.simple_delete_txn(
txn,
table="sliding_sync_connections",
keyvalues={
"user_id": user_id,
"effective_device_id": device_id,
"conn_id": conn_id,
},
)
(connection_key,) = self.db_pool.simple_insert_returning_txn(
txn,
table="sliding_sync_connections",
values={
"user_id": user_id,
"effective_device_id": device_id,
"conn_id": conn_id,
"created_ts": self._clock.time_msec(),
},
returning=("connection_key",),
)
# Define a new connection position for the updates
(connection_position,) = self.db_pool.simple_insert_returning_txn(
txn,
table="sliding_sync_connection_positions",
values={
"connection_key": connection_key,
"created_ts": self._clock.time_msec(),
},
returning=("connection_position",),
)
# We need to deduplicate the `required_state` JSON. We do this by
# fetching all JSON associated with the connection and comparing that
# with the updates to `required_state`
# Dict from required state json -> required state ID
required_state_to_id: Dict[str, int] = {}
if previous_connection_position is not None:
rows = self.db_pool.simple_select_list_txn(
txn,
table="sliding_sync_connection_required_state",
keyvalues={"connection_key": connection_key},
retcols=("required_state_id", "required_state"),
)
for required_state_id, required_state in rows:
required_state_to_id[required_state] = required_state_id
room_to_state_ids: Dict[str, int] = {}
unique_required_state: Dict[str, List[str]] = {}
for room_id, room_state in per_connection_state.room_configs.items():
serialized_state = json_encoder.encode(
# We store the required state as a sorted list of event type /
# state key tuples.
sorted(
(event_type, state_key)
for event_type, state_keys in room_state.required_state_map.items()
for state_key in state_keys
)
)
existing_state_id = required_state_to_id.get(serialized_state)
if existing_state_id is not None:
room_to_state_ids[room_id] = existing_state_id
else:
unique_required_state.setdefault(serialized_state, []).append(room_id)
# Insert any new `required_state` json we haven't previously seen.
for serialized_required_state, room_ids in unique_required_state.items():
(required_state_id,) = self.db_pool.simple_insert_returning_txn(
txn,
table="sliding_sync_connection_required_state",
values={
"connection_key": connection_key,
"required_state": serialized_required_state,
},
returning=("required_state_id",),
)
for room_id in room_ids:
room_to_state_ids[room_id] = required_state_id
# Copy over state from the previous connection position (we'll overwrite
# these rows with any changes).
if previous_connection_position is not None:
sql = """
INSERT INTO sliding_sync_connection_streams
(connection_position, stream, room_id, room_status, last_token)
SELECT ?, stream, room_id, room_status, last_token
FROM sliding_sync_connection_streams
WHERE connection_position = ?
"""
txn.execute(sql, (connection_position, previous_connection_position))
sql = """
INSERT INTO sliding_sync_connection_room_configs
(connection_position, room_id, timeline_limit, required_state_id)
SELECT ?, room_id, timeline_limit, required_state_id
FROM sliding_sync_connection_room_configs
WHERE connection_position = ?
"""
txn.execute(sql, (connection_position, previous_connection_position))
# We now upsert the changes to the various streams.
key_values = []
value_values = []
for room_id, have_sent_room in per_connection_state.rooms._statuses.items():
key_values.append((connection_position, "rooms", room_id))
value_values.append(
(have_sent_room.status.value, have_sent_room.last_token)
)
for room_id, have_sent_room in per_connection_state.receipts._statuses.items():
key_values.append((connection_position, "receipts", room_id))
value_values.append(
(have_sent_room.status.value, have_sent_room.last_token)
)
self.db_pool.simple_upsert_many_txn(
txn,
table="sliding_sync_connection_streams",
key_names=(
"connection_position",
"stream",
"room_id",
),
key_values=key_values,
value_names=(
"room_status",
"last_token",
),
value_values=value_values,
)
# ... and upsert changes to the room configs.
keys = []
values = []
for room_id, room_config in per_connection_state.room_configs.items():
keys.append((connection_position, room_id))
values.append((room_config.timeline_limit, room_to_state_ids[room_id]))
self.db_pool.simple_upsert_many_txn(
txn,
table="sliding_sync_connection_room_configs",
key_names=(
"connection_position",
"room_id",
),
key_values=keys,
value_names=(
"timeline_limit",
"required_state_id",
),
value_values=values,
)
return connection_position
@cached(iterable=True, max_entries=100000)
async def get_and_clear_connection_positions(
self, user_id: str, device_id: str, conn_id: str, connection_position: int
) -> "PerConnectionState":
"""Get the per-connection state for the given connection position."""
per_connection_state_db = await self.db_pool.runInteraction(
"get_and_clear_connection_positions",
self._get_and_clear_connection_positions_txn,
user_id=user_id,
device_id=device_id,
conn_id=conn_id,
connection_position=connection_position,
)
# This cast is safe because the downstream code only cares about
# `store.get_id_for_instance(...)` and `StreamWorkerStore` is mixed
# alongside `SlidingSyncStore` wherever we create a store.
store = cast("DataStore", self)
return await per_connection_state_db.to_state(store)
def _get_and_clear_connection_positions_txn(
self,
txn: LoggingTransaction,
user_id: str,
device_id: str,
conn_id: str,
connection_position: int,
) -> "PerConnectionStateDB":
# The `previous_connection_position` is a user-supplied value, so we
# need to make sure that the one they supplied is actually theirs.
sql = """
SELECT connection_key
FROM sliding_sync_connection_positions
INNER JOIN sliding_sync_connections USING (connection_key)
WHERE
connection_position = ?
AND user_id = ? AND effective_device_id = ? AND conn_id = ?
"""
txn.execute(sql, (connection_position, user_id, device_id, conn_id))
row = txn.fetchone()
if row is None:
raise SlidingSyncUnknownPosition()
(connection_key,) = row
# Now that we have seen the client has received and used the connection
# position, we can delete all the other connection positions.
sql = """
DELETE FROM sliding_sync_connection_positions
WHERE connection_key = ? AND connection_position != ?
"""
txn.execute(sql, (connection_key, connection_position))
# Fetch and create a mapping from required state ID to the actual
# required state for the connection.
rows = self.db_pool.simple_select_list_txn(
txn,
table="sliding_sync_connection_required_state",
keyvalues={"connection_key": connection_key},
retcols=(
"required_state_id",
"required_state",
),
)
required_state_map: Dict[int, Dict[str, Set[str]]] = {}
for row in rows:
state = required_state_map[row[0]] = {}
for event_type, state_keys in db_to_json(row[1]):
state[event_type] = set(state_keys)
# Get all the room configs, looking up the required state from the map
# above.
room_config_rows = self.db_pool.simple_select_list_txn(
txn,
table="sliding_sync_connection_room_configs",
keyvalues={"connection_position": connection_position},
retcols=(
"room_id",
"timeline_limit",
"required_state_id",
),
)
room_configs: Dict[str, RoomSyncConfig] = {}
for (
room_id,
timeline_limit,
required_state_id,
) in room_config_rows:
room_configs[room_id] = RoomSyncConfig(
timeline_limit=timeline_limit,
required_state_map=required_state_map[required_state_id],
)
# Now look up the per-room stream data.
rooms: Dict[str, HaveSentRoom[str]] = {}
receipts: Dict[str, HaveSentRoom[str]] = {}
receipt_rows = self.db_pool.simple_select_list_txn(
txn,
table="sliding_sync_connection_streams",
keyvalues={"connection_position": connection_position},
retcols=(
"stream",
"room_id",
"room_status",
"last_token",
),
)
for stream, room_id, room_status, last_token in receipt_rows:
have_sent_room: HaveSentRoom[str] = HaveSentRoom(
status=HaveSentRoomFlag(room_status), last_token=last_token
)
if stream == "rooms":
rooms[room_id] = have_sent_room
elif stream == "receipts":
receipts[room_id] = have_sent_room
else:
# For forwards compatibility we ignore unknown streams, as in
# future we want to be able to easily add more stream types.
logger.warning("Unrecognized sliding sync stream in DB %r", stream)
return PerConnectionStateDB(
rooms=RoomStatusMap(rooms),
receipts=RoomStatusMap(receipts),
room_configs=room_configs,
)
@attr.s(auto_attribs=True, frozen=True)
class PerConnectionStateDB:
"""An equivalent to `PerConnectionState` that holds data in a format stored
in the DB.
The principle difference is that the tokens for the different streams are
serialized to strings.
When persisting this *only* contains updates to the state.
"""
rooms: "RoomStatusMap[str]"
receipts: "RoomStatusMap[str]"
room_configs: Mapping[str, "RoomSyncConfig"]
@staticmethod
async def from_state(
per_connection_state: "MutablePerConnectionState", store: "DataStore"
) -> "PerConnectionStateDB":
"""Convert from a standard `PerConnectionState`"""
rooms = {
room_id: HaveSentRoom(
status=status.status,
last_token=(
await status.last_token.to_string(store)
if status.last_token is not None
else None
),
)
for room_id, status in per_connection_state.rooms.get_updates().items()
}
receipts = {
room_id: HaveSentRoom(
status=status.status,
last_token=(
await status.last_token.to_string(store)
if status.last_token is not None
else None
),
)
for room_id, status in per_connection_state.receipts.get_updates().items()
}
log_kv(
{
"rooms": rooms,
"receipts": receipts,
"room_configs": per_connection_state.room_configs.maps[0],
}
)
return PerConnectionStateDB(
rooms=RoomStatusMap(rooms),
receipts=RoomStatusMap(receipts),
room_configs=per_connection_state.room_configs.maps[0],
)
async def to_state(self, store: "DataStore") -> "PerConnectionState":
"""Convert into a standard `PerConnectionState`"""
rooms = {
room_id: HaveSentRoom(
status=status.status,
last_token=(
await RoomStreamToken.parse(store, status.last_token)
if status.last_token is not None
else None
),
)
for room_id, status in self.rooms._statuses.items()
}
receipts = {
room_id: HaveSentRoom(
status=status.status,
last_token=(
await MultiWriterStreamToken.parse(store, status.last_token)
if status.last_token is not None
else None
),
)
for room_id, status in self.receipts._statuses.items()
}
return PerConnectionState(
rooms=RoomStatusMap(rooms),
receipts=RoomStatusMap(receipts),
room_configs=self.room_configs,
)

View file

@ -28,6 +28,11 @@ if TYPE_CHECKING:
from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.database import LoggingDatabaseConnection
# A string that will be replaced with the appropriate auto increment directive
# for the database engine, expands to an auto incrementing integer primary key.
AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER = "$%AUTO_INCREMENT_PRIMARY_KEY%$"
class IsolationLevel(IntEnum): class IsolationLevel(IntEnum):
READ_COMMITTED: int = 1 READ_COMMITTED: int = 1
REPEATABLE_READ: int = 2 REPEATABLE_READ: int = 2

View file

@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast
import psycopg2.extensions import psycopg2.extensions
from synapse.storage.engines._base import ( from synapse.storage.engines._base import (
AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER,
BaseDatabaseEngine, BaseDatabaseEngine,
IncorrectDatabaseSetup, IncorrectDatabaseSetup,
IsolationLevel, IsolationLevel,
@ -256,4 +257,10 @@ class PostgresEngine(
executing the script in its own transaction. The script transaction is executing the script in its own transaction. The script transaction is
left open and it is the responsibility of the caller to commit it. left open and it is the responsibility of the caller to commit it.
""" """
# Replace auto increment placeholder with the appropriate directive
script = script.replace(
AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER,
"BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY",
)
cursor.execute(f"COMMIT; BEGIN TRANSACTION; {script}") cursor.execute(f"COMMIT; BEGIN TRANSACTION; {script}")

View file

@ -25,6 +25,7 @@ import threading
from typing import TYPE_CHECKING, Any, List, Mapping, Optional from typing import TYPE_CHECKING, Any, List, Mapping, Optional
from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.engines._base import AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
if TYPE_CHECKING: if TYPE_CHECKING:
@ -168,6 +169,11 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
> first. No other implicit transaction control is performed; any transaction > first. No other implicit transaction control is performed; any transaction
> control must be added to sql_script. > control must be added to sql_script.
""" """
# Replace auto increment placeholder with the appropriate directive
script = script.replace(
AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER, "INTEGER PRIMARY KEY AUTOINCREMENT"
)
# The implementation of `executescript` can be found at # The implementation of `executescript` can be found at
# https://github.com/python/cpython/blob/3.11/Modules/_sqlite/cursor.c#L1035. # https://github.com/python/cpython/blob/3.11/Modules/_sqlite/cursor.c#L1035.
cursor.executescript(f"BEGIN TRANSACTION; {script}") cursor.executescript(f"BEGIN TRANSACTION; {script}")

View file

@ -19,7 +19,7 @@
# #
# #
SCHEMA_VERSION = 86 # remember to update the list below when updating SCHEMA_VERSION = 87 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema """Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the This should be incremented whenever the codebase changes its requirements on the
@ -142,6 +142,11 @@ Changes in SCHEMA_VERSION = 85
Changes in SCHEMA_VERSION = 86 Changes in SCHEMA_VERSION = 86
- Add a column `authenticated` to the tables `local_media_repository` and `remote_media_cache` - Add a column `authenticated` to the tables `local_media_repository` and `remote_media_cache`
Changes in SCHEMA_VERSION = 87
- Add tables for storing the per-connection state for sliding sync requests:
sliding_sync_connections, sliding_sync_connection_positions, sliding_sync_connection_required_state,
sliding_sync_connection_room_configs, sliding_sync_connection_streams
""" """

View file

@ -0,0 +1,81 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Table to track active sliding sync connections.
--
-- A new connection will be created for every sliding sync request without a
-- `since` token for a given `conn_id` for a device.#
--
-- Once a new connection is created and used we delete all other connections for
-- the `conn_id`.
CREATE TABLE sliding_sync_connections(
connection_key $%AUTO_INCREMENT_PRIMARY_KEY%$,
user_id TEXT NOT NULL,
-- Generally the device ID, but may be something else for e.g. puppeted accounts.
effective_device_id TEXT NOT NULL,
conn_id TEXT NOT NULL,
created_ts BIGINT NOT NULL
);
CREATE INDEX sliding_sync_connections_idx ON sliding_sync_connections(user_id, effective_device_id, conn_id);
CREATE INDEX sliding_sync_connections_ts_idx ON sliding_sync_connections(created_ts);
-- We track per-connection state by associating changes to the state with
-- connection positions. This ensures that we correctly track state even if we
-- see retries of requests.
--
-- If the client starts a "new" connection (by not specifying a since token),
-- we'll clear out the other connections (to ensure that we don't end up with
-- lots of connection keys).
CREATE TABLE sliding_sync_connection_positions(
connection_position $%AUTO_INCREMENT_PRIMARY_KEY%$,
connection_key BIGINT NOT NULL REFERENCES sliding_sync_connections(connection_key) ON DELETE CASCADE,
created_ts BIGINT NOT NULL
);
CREATE INDEX sliding_sync_connection_positions_key ON sliding_sync_connection_positions(connection_key);
CREATE INDEX sliding_sync_connection_positions_ts_idx ON sliding_sync_connection_positions(created_ts);
-- To save space we deduplicate the `required_state` json by assigning IDs to
-- different values.
CREATE TABLE sliding_sync_connection_required_state(
required_state_id $%AUTO_INCREMENT_PRIMARY_KEY%$,
connection_key BIGINT NOT NULL REFERENCES sliding_sync_connections(connection_key) ON DELETE CASCADE,
required_state TEXT NOT NULL -- We store this as a json list of event type / state key tuples.
);
CREATE INDEX sliding_sync_connection_required_state_conn_pos ON sliding_sync_connection_required_state(connection_key);
-- Stores the room configs we have seen for rooms in a connection.
CREATE TABLE sliding_sync_connection_room_configs(
connection_position BIGINT NOT NULL REFERENCES sliding_sync_connection_positions(connection_position) ON DELETE CASCADE,
room_id TEXT NOT NULL,
timeline_limit BIGINT NOT NULL,
required_state_id BIGINT NOT NULL REFERENCES sliding_sync_connection_required_state(required_state_id)
);
CREATE UNIQUE INDEX sliding_sync_connection_room_configs_idx ON sliding_sync_connection_room_configs(connection_position, room_id);
-- Stores what data we have sent for given streams down given connections.
CREATE TABLE sliding_sync_connection_streams(
connection_position BIGINT NOT NULL REFERENCES sliding_sync_connection_positions(connection_position) ON DELETE CASCADE,
stream TEXT NOT NULL, -- e.g. "events" or "receipts"
room_id TEXT NOT NULL,
room_status TEXT NOT NULL, -- "live" or "previously", i.e. the `HaveSentRoomFlag` value
last_token TEXT -- For "previously" the token for the stream we have sent up to.
);
CREATE UNIQUE INDEX sliding_sync_connection_streams_idx ON sliding_sync_connection_streams(connection_position, room_id, stream);

View file

@ -730,6 +730,9 @@ class RoomStatusMap(Generic[T]):
return RoomStatusMap(statuses=dict(self._statuses)) return RoomStatusMap(statuses=dict(self._statuses))
def __len__(self) -> int:
return len(self._statuses)
class MutableRoomStatusMap(RoomStatusMap[T]): class MutableRoomStatusMap(RoomStatusMap[T]):
"""A mutable version of `RoomStatusMap`""" """A mutable version of `RoomStatusMap`"""
@ -831,6 +834,9 @@ class PerConnectionState:
room_configs=dict(self.room_configs), room_configs=dict(self.room_configs),
) )
def __len__(self) -> int:
return len(self.rooms) + len(self.receipts) + len(self.room_configs)
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class MutablePerConnectionState(PerConnectionState): class MutablePerConnectionState(PerConnectionState):

View file

@ -191,8 +191,14 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
} }
_, from_token = self.do_sync(sync_body, tok=user1_tok) _, from_token = self.do_sync(sync_body, tok=user1_tok)
# Reset the in-memory cache # Reset the positions
self.hs.get_sliding_sync_handler().connection_store._connections.clear() self.get_success(
self.store.db_pool.simple_delete(
table="sliding_sync_connections",
keyvalues={"user_id": user1_id},
desc="clear_sliding_sync_connections_cache",
)
)
# Make the Sliding Sync request # Make the Sliding Sync request
channel = self.make_request( channel = self.make_request(