mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-14 14:01:59 +01:00
Sliding sync: Store the per-connection state in the database. (#17599)
Based on #17600 --------- Co-authored-by: Eric Eastwood <eric.eastwood@beta.gouv.fr>
This commit is contained in:
parent
2999a14aed
commit
e43c2b023e
14 changed files with 691 additions and 115 deletions
1
changelog.d/17599.misc
Normal file
1
changelog.d/17599.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Store sliding sync per-connection state in the database.
|
|
@ -98,6 +98,7 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
|||
from synapse.storage.databases.main.search import SearchStore
|
||||
from synapse.storage.databases.main.session import SessionStore
|
||||
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
||||
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
|
||||
from synapse.storage.databases.main.state import StateGroupWorkerStore
|
||||
from synapse.storage.databases.main.stats import StatsStore
|
||||
from synapse.storage.databases.main.stream import StreamWorkerStore
|
||||
|
@ -159,6 +160,7 @@ class GenericWorkerStore(
|
|||
SessionStore,
|
||||
TaskSchedulerWorkerStore,
|
||||
ExperimentalFeaturesStore,
|
||||
SlidingSyncStore,
|
||||
):
|
||||
# Properties that multiple storage classes define. Tell mypy what the
|
||||
# expected type is.
|
||||
|
|
|
@ -89,7 +89,7 @@ class SlidingSyncHandler:
|
|||
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
self.connection_store = SlidingSyncConnectionStore()
|
||||
self.connection_store = SlidingSyncConnectionStore(self.store)
|
||||
self.extensions = SlidingSyncExtensionHandler(hs)
|
||||
self.room_lists = SlidingSyncRoomLists(hs)
|
||||
|
||||
|
@ -210,16 +210,11 @@ class SlidingSyncHandler:
|
|||
# amount of time (more with round-trips and re-processing) in the end to
|
||||
# get everything again.
|
||||
previous_connection_state = (
|
||||
await self.connection_store.get_per_connection_state(
|
||||
await self.connection_store.get_and_clear_connection_positions(
|
||||
sync_config, from_token
|
||||
)
|
||||
)
|
||||
|
||||
await self.connection_store.mark_token_seen(
|
||||
sync_config=sync_config,
|
||||
from_token=from_token,
|
||||
)
|
||||
|
||||
# Get all of the room IDs that the user should be able to see in the sync
|
||||
# response
|
||||
has_lists = sync_config.lists is not None and len(sync_config.lists) > 0
|
||||
|
|
|
@ -13,12 +13,12 @@
|
|||
#
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.errors import SlidingSyncUnknownPosition
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.types import SlidingSyncStreamToken
|
||||
from synapse.types.handlers.sliding_sync import (
|
||||
MutablePerConnectionState,
|
||||
|
@ -61,22 +61,9 @@ class SlidingSyncConnectionStore:
|
|||
to mapping of room ID to `HaveSentRoom`.
|
||||
"""
|
||||
|
||||
# `(user_id, conn_id)` -> `connection_position` -> `PerConnectionState`
|
||||
_connections: Dict[Tuple[str, str], Dict[int, PerConnectionState]] = attr.Factory(
|
||||
dict
|
||||
)
|
||||
store: "DataStore"
|
||||
|
||||
async def is_valid_token(
|
||||
self, sync_config: SlidingSyncConfig, connection_token: int
|
||||
) -> bool:
|
||||
"""Return whether the connection token is valid/recognized"""
|
||||
if connection_token == 0:
|
||||
return True
|
||||
|
||||
conn_key = self._get_connection_key(sync_config)
|
||||
return connection_token in self._connections.get(conn_key, {})
|
||||
|
||||
async def get_per_connection_state(
|
||||
async def get_and_clear_connection_positions(
|
||||
self,
|
||||
sync_config: SlidingSyncConfig,
|
||||
from_token: Optional[SlidingSyncStreamToken],
|
||||
|
@ -86,23 +73,21 @@ class SlidingSyncConnectionStore:
|
|||
Raises:
|
||||
SlidingSyncUnknownPosition if the connection_token is unknown
|
||||
"""
|
||||
if from_token is None:
|
||||
# If this is our first request, there is no previous connection state to fetch out of the database
|
||||
if from_token is None or from_token.connection_position == 0:
|
||||
return PerConnectionState()
|
||||
|
||||
connection_position = from_token.connection_position
|
||||
if connection_position == 0:
|
||||
# Initial sync (request without a `from_token`) starts at `0` so
|
||||
# there is no existing per-connection state
|
||||
return PerConnectionState()
|
||||
conn_id = sync_config.conn_id or ""
|
||||
|
||||
conn_key = self._get_connection_key(sync_config)
|
||||
sync_statuses = self._connections.get(conn_key, {})
|
||||
connection_state = sync_statuses.get(connection_position)
|
||||
device_id = sync_config.requester.device_id
|
||||
assert device_id is not None
|
||||
|
||||
if connection_state is None:
|
||||
raise SlidingSyncUnknownPosition()
|
||||
|
||||
return connection_state
|
||||
return await self.store.get_and_clear_connection_positions(
|
||||
sync_config.user.to_string(),
|
||||
device_id,
|
||||
conn_id,
|
||||
from_token.connection_position,
|
||||
)
|
||||
|
||||
@trace
|
||||
async def record_new_state(
|
||||
|
@ -116,85 +101,28 @@ class SlidingSyncConnectionStore:
|
|||
If there are no changes to the state this may return the same token as
|
||||
the existing per-connection state.
|
||||
"""
|
||||
prev_connection_token = 0
|
||||
if from_token is not None:
|
||||
prev_connection_token = from_token.connection_position
|
||||
|
||||
if not new_connection_state.has_updates():
|
||||
return prev_connection_token
|
||||
if from_token is not None:
|
||||
return from_token.connection_position
|
||||
else:
|
||||
return 0
|
||||
|
||||
conn_key = self._get_connection_key(sync_config)
|
||||
sync_statuses = self._connections.setdefault(conn_key, {})
|
||||
# A from token with a zero connection position means there was no
|
||||
# previously stored connection state, so we treat a zero the same as
|
||||
# there being no previous position.
|
||||
previous_connection_position = None
|
||||
if from_token is not None and from_token.connection_position != 0:
|
||||
previous_connection_position = from_token.connection_position
|
||||
|
||||
# Generate a new token, removing any existing entries in that token
|
||||
# (which can happen if requests get resent).
|
||||
new_store_token = prev_connection_token + 1
|
||||
sync_statuses.pop(new_store_token, None)
|
||||
|
||||
# We copy the `MutablePerConnectionState` so that the inner `ChainMap`s
|
||||
# don't grow forever.
|
||||
sync_statuses[new_store_token] = new_connection_state.copy()
|
||||
|
||||
return new_store_token
|
||||
|
||||
@trace
|
||||
async def mark_token_seen(
|
||||
self,
|
||||
sync_config: SlidingSyncConfig,
|
||||
from_token: Optional[SlidingSyncStreamToken],
|
||||
) -> None:
|
||||
"""We have received a request with the given token, so we can clear out
|
||||
any other tokens associated with the connection.
|
||||
|
||||
If there is no from token then we have started afresh, and so we delete
|
||||
all tokens associated with the device.
|
||||
"""
|
||||
# Clear out any tokens for the connection that doesn't match the one
|
||||
# from the request.
|
||||
|
||||
conn_key = self._get_connection_key(sync_config)
|
||||
sync_statuses = self._connections.pop(conn_key, {})
|
||||
if from_token is None:
|
||||
return
|
||||
|
||||
sync_statuses = {
|
||||
connection_token: room_statuses
|
||||
for connection_token, room_statuses in sync_statuses.items()
|
||||
if connection_token == from_token.connection_position
|
||||
}
|
||||
if sync_statuses:
|
||||
self._connections[conn_key] = sync_statuses
|
||||
|
||||
@staticmethod
|
||||
def _get_connection_key(sync_config: SlidingSyncConfig) -> Tuple[str, str]:
|
||||
"""Return a unique identifier for this connection.
|
||||
|
||||
The first part is simply the user ID.
|
||||
|
||||
The second part is generally a combination of device ID and conn_id.
|
||||
However, both these two are optional (e.g. puppet access tokens don't
|
||||
have device IDs), so this handles those edge cases.
|
||||
|
||||
We use this over the raw `conn_id` to avoid clashes between different
|
||||
clients that use the same `conn_id`. Imagine a user uses a web client
|
||||
that uses `conn_id: main_sync_loop` and an Android client that also has
|
||||
a `conn_id: main_sync_loop`.
|
||||
"""
|
||||
|
||||
user_id = sync_config.user.to_string()
|
||||
|
||||
# Only one sliding sync connection is allowed per given conn_id (empty
|
||||
# or not).
|
||||
conn_id = sync_config.conn_id or ""
|
||||
|
||||
if sync_config.requester.device_id:
|
||||
return (user_id, f"D/{sync_config.requester.device_id}/{conn_id}")
|
||||
device_id = sync_config.requester.device_id
|
||||
assert device_id is not None
|
||||
|
||||
if sync_config.requester.access_token_id:
|
||||
# If we don't have a device, then the access token ID should be a
|
||||
# stable ID.
|
||||
return (user_id, f"A/{sync_config.requester.access_token_id}/{conn_id}")
|
||||
|
||||
# If we have neither then its likely an AS or some weird token. Either
|
||||
# way we can just fail here.
|
||||
raise Exception("Cannot use sliding sync with access token type")
|
||||
return await self.store.persist_per_connection_state(
|
||||
sync_config.user.to_string(),
|
||||
device_id,
|
||||
conn_id,
|
||||
previous_connection_position,
|
||||
new_connection_state,
|
||||
)
|
||||
|
|
|
@ -65,6 +65,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
|||
from synapse.storage.background_updates import BackgroundUpdater
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.types import Connection, Cursor, SQLQueryParameters
|
||||
from synapse.types import StrCollection
|
||||
from synapse.util.async_helpers import delay_cancellation
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
|
@ -1096,6 +1097,48 @@ class DatabasePool:
|
|||
|
||||
txn.execute(sql, vals)
|
||||
|
||||
@staticmethod
|
||||
def simple_insert_returning_txn(
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
values: Dict[str, Any],
|
||||
returning: StrCollection,
|
||||
) -> Tuple[Any, ...]:
|
||||
"""Executes a `INSERT INTO... RETURNING...` statement (or equivalent for
|
||||
SQLite versions that don't support it).
|
||||
"""
|
||||
|
||||
if txn.database_engine.supports_returning:
|
||||
sql = "INSERT INTO %s (%s) VALUES(%s) RETURNING %s" % (
|
||||
table,
|
||||
", ".join(k for k in values.keys()),
|
||||
", ".join("?" for _ in values.keys()),
|
||||
", ".join(k for k in returning),
|
||||
)
|
||||
|
||||
txn.execute(sql, list(values.values()))
|
||||
row = txn.fetchone()
|
||||
assert row is not None
|
||||
return row
|
||||
else:
|
||||
# For old versions of SQLite we do a standard insert and then can
|
||||
# use `last_insert_rowid` to get at the row we just inserted
|
||||
DatabasePool.simple_insert_txn(
|
||||
txn,
|
||||
table=table,
|
||||
values=values,
|
||||
)
|
||||
txn.execute("SELECT last_insert_rowid()")
|
||||
row = txn.fetchone()
|
||||
assert row is not None
|
||||
(rowid,) = row
|
||||
|
||||
row = DatabasePool.simple_select_one_txn(
|
||||
txn, table=table, keyvalues={"rowid": rowid}, retcols=returning
|
||||
)
|
||||
assert row is not None
|
||||
return row
|
||||
|
||||
async def simple_insert_many(
|
||||
self,
|
||||
table: str,
|
||||
|
|
|
@ -33,6 +33,7 @@ from synapse.storage.database import (
|
|||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
|
||||
from synapse.storage.databases.main.stats import UserSortOrder
|
||||
from synapse.storage.engines import BaseDatabaseEngine
|
||||
from synapse.storage.types import Cursor
|
||||
|
@ -156,6 +157,7 @@ class DataStore(
|
|||
LockStore,
|
||||
SessionStore,
|
||||
TaskSchedulerWorkerStore,
|
||||
SlidingSyncStore,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
491
synapse/storage/databases/main/sliding_sync.py
Normal file
491
synapse/storage/databases/main/sliding_sync.py
Normal file
|
@ -0,0 +1,491 @@
|
|||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2023 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, cast
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.errors import SlidingSyncUnknownPosition
|
||||
from synapse.logging.opentracing import log_kv
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.types import MultiWriterStreamToken, RoomStreamToken
|
||||
from synapse.types.handlers.sliding_sync import (
|
||||
HaveSentRoom,
|
||||
HaveSentRoomFlag,
|
||||
MutablePerConnectionState,
|
||||
PerConnectionState,
|
||||
RoomStatusMap,
|
||||
RoomSyncConfig,
|
||||
)
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SlidingSyncStore(SQLBaseStore):
|
||||
async def persist_per_connection_state(
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
conn_id: str,
|
||||
previous_connection_position: Optional[int],
|
||||
per_connection_state: "MutablePerConnectionState",
|
||||
) -> int:
|
||||
"""Persist updates to the per-connection state for a sliding sync
|
||||
connection.
|
||||
|
||||
Returns:
|
||||
The connection position of the newly persisted state.
|
||||
"""
|
||||
|
||||
# This cast is safe because the downstream code only cares about
|
||||
# `store.get_id_for_instance(...)` and `StreamWorkerStore` is mixed
|
||||
# alongside `SlidingSyncStore` wherever we create a store.
|
||||
store = cast("DataStore", self)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"persist_per_connection_state",
|
||||
self.persist_per_connection_state_txn,
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
conn_id=conn_id,
|
||||
previous_connection_position=previous_connection_position,
|
||||
per_connection_state=await PerConnectionStateDB.from_state(
|
||||
per_connection_state, store
|
||||
),
|
||||
)
|
||||
|
||||
def persist_per_connection_state_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
conn_id: str,
|
||||
previous_connection_position: Optional[int],
|
||||
per_connection_state: "PerConnectionStateDB",
|
||||
) -> int:
|
||||
# First we fetch (or create) the connection key associated with the
|
||||
# previous connection position.
|
||||
if previous_connection_position is not None:
|
||||
# The `previous_connection_position` is a user-supplied value, so we
|
||||
# need to make sure that the one they supplied is actually theirs.
|
||||
sql = """
|
||||
SELECT connection_key
|
||||
FROM sliding_sync_connection_positions
|
||||
INNER JOIN sliding_sync_connections USING (connection_key)
|
||||
WHERE
|
||||
connection_position = ?
|
||||
AND user_id = ? AND effective_device_id = ? AND conn_id = ?
|
||||
"""
|
||||
txn.execute(
|
||||
sql, (previous_connection_position, user_id, device_id, conn_id)
|
||||
)
|
||||
row = txn.fetchone()
|
||||
if row is None:
|
||||
raise SlidingSyncUnknownPosition()
|
||||
|
||||
(connection_key,) = row
|
||||
else:
|
||||
# We're restarting the connection, so we clear the previous existing data we
|
||||
# used to track it. We do this here to ensure that if we get lots of
|
||||
# one-shot requests we don't stack up lots of entries. We have `ON DELETE
|
||||
# CASCADE` setup on the dependent tables so this will clear out all the
|
||||
# associated data.
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
table="sliding_sync_connections",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"effective_device_id": device_id,
|
||||
"conn_id": conn_id,
|
||||
},
|
||||
)
|
||||
|
||||
(connection_key,) = self.db_pool.simple_insert_returning_txn(
|
||||
txn,
|
||||
table="sliding_sync_connections",
|
||||
values={
|
||||
"user_id": user_id,
|
||||
"effective_device_id": device_id,
|
||||
"conn_id": conn_id,
|
||||
"created_ts": self._clock.time_msec(),
|
||||
},
|
||||
returning=("connection_key",),
|
||||
)
|
||||
|
||||
# Define a new connection position for the updates
|
||||
(connection_position,) = self.db_pool.simple_insert_returning_txn(
|
||||
txn,
|
||||
table="sliding_sync_connection_positions",
|
||||
values={
|
||||
"connection_key": connection_key,
|
||||
"created_ts": self._clock.time_msec(),
|
||||
},
|
||||
returning=("connection_position",),
|
||||
)
|
||||
|
||||
# We need to deduplicate the `required_state` JSON. We do this by
|
||||
# fetching all JSON associated with the connection and comparing that
|
||||
# with the updates to `required_state`
|
||||
|
||||
# Dict from required state json -> required state ID
|
||||
required_state_to_id: Dict[str, int] = {}
|
||||
if previous_connection_position is not None:
|
||||
rows = self.db_pool.simple_select_list_txn(
|
||||
txn,
|
||||
table="sliding_sync_connection_required_state",
|
||||
keyvalues={"connection_key": connection_key},
|
||||
retcols=("required_state_id", "required_state"),
|
||||
)
|
||||
for required_state_id, required_state in rows:
|
||||
required_state_to_id[required_state] = required_state_id
|
||||
|
||||
room_to_state_ids: Dict[str, int] = {}
|
||||
unique_required_state: Dict[str, List[str]] = {}
|
||||
for room_id, room_state in per_connection_state.room_configs.items():
|
||||
serialized_state = json_encoder.encode(
|
||||
# We store the required state as a sorted list of event type /
|
||||
# state key tuples.
|
||||
sorted(
|
||||
(event_type, state_key)
|
||||
for event_type, state_keys in room_state.required_state_map.items()
|
||||
for state_key in state_keys
|
||||
)
|
||||
)
|
||||
|
||||
existing_state_id = required_state_to_id.get(serialized_state)
|
||||
if existing_state_id is not None:
|
||||
room_to_state_ids[room_id] = existing_state_id
|
||||
else:
|
||||
unique_required_state.setdefault(serialized_state, []).append(room_id)
|
||||
|
||||
# Insert any new `required_state` json we haven't previously seen.
|
||||
for serialized_required_state, room_ids in unique_required_state.items():
|
||||
(required_state_id,) = self.db_pool.simple_insert_returning_txn(
|
||||
txn,
|
||||
table="sliding_sync_connection_required_state",
|
||||
values={
|
||||
"connection_key": connection_key,
|
||||
"required_state": serialized_required_state,
|
||||
},
|
||||
returning=("required_state_id",),
|
||||
)
|
||||
for room_id in room_ids:
|
||||
room_to_state_ids[room_id] = required_state_id
|
||||
|
||||
# Copy over state from the previous connection position (we'll overwrite
|
||||
# these rows with any changes).
|
||||
if previous_connection_position is not None:
|
||||
sql = """
|
||||
INSERT INTO sliding_sync_connection_streams
|
||||
(connection_position, stream, room_id, room_status, last_token)
|
||||
SELECT ?, stream, room_id, room_status, last_token
|
||||
FROM sliding_sync_connection_streams
|
||||
WHERE connection_position = ?
|
||||
"""
|
||||
txn.execute(sql, (connection_position, previous_connection_position))
|
||||
|
||||
sql = """
|
||||
INSERT INTO sliding_sync_connection_room_configs
|
||||
(connection_position, room_id, timeline_limit, required_state_id)
|
||||
SELECT ?, room_id, timeline_limit, required_state_id
|
||||
FROM sliding_sync_connection_room_configs
|
||||
WHERE connection_position = ?
|
||||
"""
|
||||
txn.execute(sql, (connection_position, previous_connection_position))
|
||||
|
||||
# We now upsert the changes to the various streams.
|
||||
key_values = []
|
||||
value_values = []
|
||||
for room_id, have_sent_room in per_connection_state.rooms._statuses.items():
|
||||
key_values.append((connection_position, "rooms", room_id))
|
||||
value_values.append(
|
||||
(have_sent_room.status.value, have_sent_room.last_token)
|
||||
)
|
||||
|
||||
for room_id, have_sent_room in per_connection_state.receipts._statuses.items():
|
||||
key_values.append((connection_position, "receipts", room_id))
|
||||
value_values.append(
|
||||
(have_sent_room.status.value, have_sent_room.last_token)
|
||||
)
|
||||
|
||||
self.db_pool.simple_upsert_many_txn(
|
||||
txn,
|
||||
table="sliding_sync_connection_streams",
|
||||
key_names=(
|
||||
"connection_position",
|
||||
"stream",
|
||||
"room_id",
|
||||
),
|
||||
key_values=key_values,
|
||||
value_names=(
|
||||
"room_status",
|
||||
"last_token",
|
||||
),
|
||||
value_values=value_values,
|
||||
)
|
||||
|
||||
# ... and upsert changes to the room configs.
|
||||
keys = []
|
||||
values = []
|
||||
for room_id, room_config in per_connection_state.room_configs.items():
|
||||
keys.append((connection_position, room_id))
|
||||
values.append((room_config.timeline_limit, room_to_state_ids[room_id]))
|
||||
|
||||
self.db_pool.simple_upsert_many_txn(
|
||||
txn,
|
||||
table="sliding_sync_connection_room_configs",
|
||||
key_names=(
|
||||
"connection_position",
|
||||
"room_id",
|
||||
),
|
||||
key_values=keys,
|
||||
value_names=(
|
||||
"timeline_limit",
|
||||
"required_state_id",
|
||||
),
|
||||
value_values=values,
|
||||
)
|
||||
|
||||
return connection_position
|
||||
|
||||
@cached(iterable=True, max_entries=100000)
|
||||
async def get_and_clear_connection_positions(
|
||||
self, user_id: str, device_id: str, conn_id: str, connection_position: int
|
||||
) -> "PerConnectionState":
|
||||
"""Get the per-connection state for the given connection position."""
|
||||
|
||||
per_connection_state_db = await self.db_pool.runInteraction(
|
||||
"get_and_clear_connection_positions",
|
||||
self._get_and_clear_connection_positions_txn,
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
conn_id=conn_id,
|
||||
connection_position=connection_position,
|
||||
)
|
||||
|
||||
# This cast is safe because the downstream code only cares about
|
||||
# `store.get_id_for_instance(...)` and `StreamWorkerStore` is mixed
|
||||
# alongside `SlidingSyncStore` wherever we create a store.
|
||||
store = cast("DataStore", self)
|
||||
|
||||
return await per_connection_state_db.to_state(store)
|
||||
|
||||
def _get_and_clear_connection_positions_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
conn_id: str,
|
||||
connection_position: int,
|
||||
) -> "PerConnectionStateDB":
|
||||
# The `previous_connection_position` is a user-supplied value, so we
|
||||
# need to make sure that the one they supplied is actually theirs.
|
||||
sql = """
|
||||
SELECT connection_key
|
||||
FROM sliding_sync_connection_positions
|
||||
INNER JOIN sliding_sync_connections USING (connection_key)
|
||||
WHERE
|
||||
connection_position = ?
|
||||
AND user_id = ? AND effective_device_id = ? AND conn_id = ?
|
||||
"""
|
||||
txn.execute(sql, (connection_position, user_id, device_id, conn_id))
|
||||
row = txn.fetchone()
|
||||
if row is None:
|
||||
raise SlidingSyncUnknownPosition()
|
||||
|
||||
(connection_key,) = row
|
||||
|
||||
# Now that we have seen the client has received and used the connection
|
||||
# position, we can delete all the other connection positions.
|
||||
sql = """
|
||||
DELETE FROM sliding_sync_connection_positions
|
||||
WHERE connection_key = ? AND connection_position != ?
|
||||
"""
|
||||
txn.execute(sql, (connection_key, connection_position))
|
||||
|
||||
# Fetch and create a mapping from required state ID to the actual
|
||||
# required state for the connection.
|
||||
rows = self.db_pool.simple_select_list_txn(
|
||||
txn,
|
||||
table="sliding_sync_connection_required_state",
|
||||
keyvalues={"connection_key": connection_key},
|
||||
retcols=(
|
||||
"required_state_id",
|
||||
"required_state",
|
||||
),
|
||||
)
|
||||
|
||||
required_state_map: Dict[int, Dict[str, Set[str]]] = {}
|
||||
for row in rows:
|
||||
state = required_state_map[row[0]] = {}
|
||||
for event_type, state_keys in db_to_json(row[1]):
|
||||
state[event_type] = set(state_keys)
|
||||
|
||||
# Get all the room configs, looking up the required state from the map
|
||||
# above.
|
||||
room_config_rows = self.db_pool.simple_select_list_txn(
|
||||
txn,
|
||||
table="sliding_sync_connection_room_configs",
|
||||
keyvalues={"connection_position": connection_position},
|
||||
retcols=(
|
||||
"room_id",
|
||||
"timeline_limit",
|
||||
"required_state_id",
|
||||
),
|
||||
)
|
||||
|
||||
room_configs: Dict[str, RoomSyncConfig] = {}
|
||||
for (
|
||||
room_id,
|
||||
timeline_limit,
|
||||
required_state_id,
|
||||
) in room_config_rows:
|
||||
room_configs[room_id] = RoomSyncConfig(
|
||||
timeline_limit=timeline_limit,
|
||||
required_state_map=required_state_map[required_state_id],
|
||||
)
|
||||
|
||||
# Now look up the per-room stream data.
|
||||
rooms: Dict[str, HaveSentRoom[str]] = {}
|
||||
receipts: Dict[str, HaveSentRoom[str]] = {}
|
||||
|
||||
receipt_rows = self.db_pool.simple_select_list_txn(
|
||||
txn,
|
||||
table="sliding_sync_connection_streams",
|
||||
keyvalues={"connection_position": connection_position},
|
||||
retcols=(
|
||||
"stream",
|
||||
"room_id",
|
||||
"room_status",
|
||||
"last_token",
|
||||
),
|
||||
)
|
||||
for stream, room_id, room_status, last_token in receipt_rows:
|
||||
have_sent_room: HaveSentRoom[str] = HaveSentRoom(
|
||||
status=HaveSentRoomFlag(room_status), last_token=last_token
|
||||
)
|
||||
if stream == "rooms":
|
||||
rooms[room_id] = have_sent_room
|
||||
elif stream == "receipts":
|
||||
receipts[room_id] = have_sent_room
|
||||
else:
|
||||
# For forwards compatibility we ignore unknown streams, as in
|
||||
# future we want to be able to easily add more stream types.
|
||||
logger.warning("Unrecognized sliding sync stream in DB %r", stream)
|
||||
|
||||
return PerConnectionStateDB(
|
||||
rooms=RoomStatusMap(rooms),
|
||||
receipts=RoomStatusMap(receipts),
|
||||
room_configs=room_configs,
|
||||
)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, frozen=True)
|
||||
class PerConnectionStateDB:
|
||||
"""An equivalent to `PerConnectionState` that holds data in a format stored
|
||||
in the DB.
|
||||
|
||||
The principle difference is that the tokens for the different streams are
|
||||
serialized to strings.
|
||||
|
||||
When persisting this *only* contains updates to the state.
|
||||
"""
|
||||
|
||||
rooms: "RoomStatusMap[str]"
|
||||
receipts: "RoomStatusMap[str]"
|
||||
|
||||
room_configs: Mapping[str, "RoomSyncConfig"]
|
||||
|
||||
@staticmethod
|
||||
async def from_state(
|
||||
per_connection_state: "MutablePerConnectionState", store: "DataStore"
|
||||
) -> "PerConnectionStateDB":
|
||||
"""Convert from a standard `PerConnectionState`"""
|
||||
rooms = {
|
||||
room_id: HaveSentRoom(
|
||||
status=status.status,
|
||||
last_token=(
|
||||
await status.last_token.to_string(store)
|
||||
if status.last_token is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for room_id, status in per_connection_state.rooms.get_updates().items()
|
||||
}
|
||||
|
||||
receipts = {
|
||||
room_id: HaveSentRoom(
|
||||
status=status.status,
|
||||
last_token=(
|
||||
await status.last_token.to_string(store)
|
||||
if status.last_token is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for room_id, status in per_connection_state.receipts.get_updates().items()
|
||||
}
|
||||
|
||||
log_kv(
|
||||
{
|
||||
"rooms": rooms,
|
||||
"receipts": receipts,
|
||||
"room_configs": per_connection_state.room_configs.maps[0],
|
||||
}
|
||||
)
|
||||
|
||||
return PerConnectionStateDB(
|
||||
rooms=RoomStatusMap(rooms),
|
||||
receipts=RoomStatusMap(receipts),
|
||||
room_configs=per_connection_state.room_configs.maps[0],
|
||||
)
|
||||
|
||||
async def to_state(self, store: "DataStore") -> "PerConnectionState":
|
||||
"""Convert into a standard `PerConnectionState`"""
|
||||
rooms = {
|
||||
room_id: HaveSentRoom(
|
||||
status=status.status,
|
||||
last_token=(
|
||||
await RoomStreamToken.parse(store, status.last_token)
|
||||
if status.last_token is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for room_id, status in self.rooms._statuses.items()
|
||||
}
|
||||
|
||||
receipts = {
|
||||
room_id: HaveSentRoom(
|
||||
status=status.status,
|
||||
last_token=(
|
||||
await MultiWriterStreamToken.parse(store, status.last_token)
|
||||
if status.last_token is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for room_id, status in self.receipts._statuses.items()
|
||||
}
|
||||
|
||||
return PerConnectionState(
|
||||
rooms=RoomStatusMap(rooms),
|
||||
receipts=RoomStatusMap(receipts),
|
||||
room_configs=self.room_configs,
|
||||
)
|
|
@ -28,6 +28,11 @@ if TYPE_CHECKING:
|
|||
from synapse.storage.database import LoggingDatabaseConnection
|
||||
|
||||
|
||||
# A string that will be replaced with the appropriate auto increment directive
|
||||
# for the database engine, expands to an auto incrementing integer primary key.
|
||||
AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER = "$%AUTO_INCREMENT_PRIMARY_KEY%$"
|
||||
|
||||
|
||||
class IsolationLevel(IntEnum):
|
||||
READ_COMMITTED: int = 1
|
||||
REPEATABLE_READ: int = 2
|
||||
|
|
|
@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast
|
|||
import psycopg2.extensions
|
||||
|
||||
from synapse.storage.engines._base import (
|
||||
AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER,
|
||||
BaseDatabaseEngine,
|
||||
IncorrectDatabaseSetup,
|
||||
IsolationLevel,
|
||||
|
@ -256,4 +257,10 @@ class PostgresEngine(
|
|||
executing the script in its own transaction. The script transaction is
|
||||
left open and it is the responsibility of the caller to commit it.
|
||||
"""
|
||||
# Replace auto increment placeholder with the appropriate directive
|
||||
script = script.replace(
|
||||
AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER,
|
||||
"BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY",
|
||||
)
|
||||
|
||||
cursor.execute(f"COMMIT; BEGIN TRANSACTION; {script}")
|
||||
|
|
|
@ -25,6 +25,7 @@ import threading
|
|||
from typing import TYPE_CHECKING, Any, List, Mapping, Optional
|
||||
|
||||
from synapse.storage.engines import BaseDatabaseEngine
|
||||
from synapse.storage.engines._base import AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER
|
||||
from synapse.storage.types import Cursor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -168,6 +169,11 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
|
|||
> first. No other implicit transaction control is performed; any transaction
|
||||
> control must be added to sql_script.
|
||||
"""
|
||||
# Replace auto increment placeholder with the appropriate directive
|
||||
script = script.replace(
|
||||
AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER, "INTEGER PRIMARY KEY AUTOINCREMENT"
|
||||
)
|
||||
|
||||
# The implementation of `executescript` can be found at
|
||||
# https://github.com/python/cpython/blob/3.11/Modules/_sqlite/cursor.c#L1035.
|
||||
cursor.executescript(f"BEGIN TRANSACTION; {script}")
|
||||
|
|
|
@ -146,6 +146,9 @@ Changes in SCHEMA_VERSION = 86
|
|||
Changes in SCHEMA_VERSION = 87
|
||||
- Add tables to store Sliding Sync data for quick filtering/sorting
|
||||
(`sliding_sync_joined_rooms`, `sliding_sync_membership_snapshots`)
|
||||
- Add tables for storing the per-connection state for sliding sync requests:
|
||||
sliding_sync_connections, sliding_sync_connection_positions, sliding_sync_connection_required_state,
|
||||
sliding_sync_connection_room_configs, sliding_sync_connection_streams
|
||||
"""
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2024 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
|
||||
-- Table to track active sliding sync connections.
|
||||
--
|
||||
-- A new connection will be created for every sliding sync request without a
|
||||
-- `since` token for a given `conn_id` for a device.#
|
||||
--
|
||||
-- Once a new connection is created and used we delete all other connections for
|
||||
-- the `conn_id`.
|
||||
CREATE TABLE sliding_sync_connections(
|
||||
connection_key $%AUTO_INCREMENT_PRIMARY_KEY%$,
|
||||
user_id TEXT NOT NULL,
|
||||
-- Generally the device ID, but may be something else for e.g. puppeted accounts.
|
||||
effective_device_id TEXT NOT NULL,
|
||||
conn_id TEXT NOT NULL,
|
||||
created_ts BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX sliding_sync_connections_idx ON sliding_sync_connections(user_id, effective_device_id, conn_id);
|
||||
CREATE INDEX sliding_sync_connections_ts_idx ON sliding_sync_connections(created_ts);
|
||||
|
||||
-- We track per-connection state by associating changes to the state with
|
||||
-- connection positions. This ensures that we correctly track state even if we
|
||||
-- see retries of requests.
|
||||
--
|
||||
-- If the client starts a "new" connection (by not specifying a since token),
|
||||
-- we'll clear out the other connections (to ensure that we don't end up with
|
||||
-- lots of connection keys).
|
||||
CREATE TABLE sliding_sync_connection_positions(
|
||||
connection_position $%AUTO_INCREMENT_PRIMARY_KEY%$,
|
||||
connection_key BIGINT NOT NULL REFERENCES sliding_sync_connections(connection_key) ON DELETE CASCADE,
|
||||
created_ts BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX sliding_sync_connection_positions_key ON sliding_sync_connection_positions(connection_key);
|
||||
CREATE INDEX sliding_sync_connection_positions_ts_idx ON sliding_sync_connection_positions(created_ts);
|
||||
|
||||
|
||||
-- To save space we deduplicate the `required_state` json by assigning IDs to
|
||||
-- different values.
|
||||
CREATE TABLE sliding_sync_connection_required_state(
|
||||
required_state_id $%AUTO_INCREMENT_PRIMARY_KEY%$,
|
||||
connection_key BIGINT NOT NULL REFERENCES sliding_sync_connections(connection_key) ON DELETE CASCADE,
|
||||
required_state TEXT NOT NULL -- We store this as a json list of event type / state key tuples.
|
||||
);
|
||||
|
||||
CREATE INDEX sliding_sync_connection_required_state_conn_pos ON sliding_sync_connection_required_state(connection_key);
|
||||
|
||||
|
||||
-- Stores the room configs we have seen for rooms in a connection.
|
||||
CREATE TABLE sliding_sync_connection_room_configs(
|
||||
connection_position BIGINT NOT NULL REFERENCES sliding_sync_connection_positions(connection_position) ON DELETE CASCADE,
|
||||
room_id TEXT NOT NULL,
|
||||
timeline_limit BIGINT NOT NULL,
|
||||
required_state_id BIGINT NOT NULL REFERENCES sliding_sync_connection_required_state(required_state_id)
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX sliding_sync_connection_room_configs_idx ON sliding_sync_connection_room_configs(connection_position, room_id);
|
||||
|
||||
-- Stores what data we have sent for given streams down given connections.
|
||||
CREATE TABLE sliding_sync_connection_streams(
|
||||
connection_position BIGINT NOT NULL REFERENCES sliding_sync_connection_positions(connection_position) ON DELETE CASCADE,
|
||||
stream TEXT NOT NULL, -- e.g. "events" or "receipts"
|
||||
room_id TEXT NOT NULL,
|
||||
room_status TEXT NOT NULL, -- "live" or "previously", i.e. the `HaveSentRoomFlag` value
|
||||
last_token TEXT -- For "previously" the token for the stream we have sent up to.
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX sliding_sync_connection_streams_idx ON sliding_sync_connection_streams(connection_position, room_id, stream);
|
|
@ -730,6 +730,9 @@ class RoomStatusMap(Generic[T]):
|
|||
|
||||
return RoomStatusMap(statuses=dict(self._statuses))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._statuses)
|
||||
|
||||
|
||||
class MutableRoomStatusMap(RoomStatusMap[T]):
|
||||
"""A mutable version of `RoomStatusMap`"""
|
||||
|
@ -831,6 +834,9 @@ class PerConnectionState:
|
|||
room_configs=dict(self.room_configs),
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.rooms) + len(self.receipts) + len(self.room_configs)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class MutablePerConnectionState(PerConnectionState):
|
||||
|
|
|
@ -191,8 +191,14 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
|
|||
}
|
||||
_, from_token = self.do_sync(sync_body, tok=user1_tok)
|
||||
|
||||
# Reset the in-memory cache
|
||||
self.hs.get_sliding_sync_handler().connection_store._connections.clear()
|
||||
# Reset the positions
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_delete(
|
||||
table="sliding_sync_connections",
|
||||
keyvalues={"user_id": user1_id},
|
||||
desc="clear_sliding_sync_connections_cache",
|
||||
)
|
||||
)
|
||||
|
||||
# Make the Sliding Sync request
|
||||
channel = self.make_request(
|
||||
|
|
Loading…
Reference in a new issue