0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 06:23:47 +01:00

Persist room hierarchy pagination sessions to the database. (#10613)

This commit is contained in:
Patrick Cloke 2021-08-24 08:14:03 -04:00 committed by GitHub
parent 15db8b7c7f
commit d12ba52f17
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 212 additions and 38 deletions

View file

@ -0,0 +1 @@
Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946).

View file

@ -57,6 +57,7 @@ files =
synapse/storage/databases/main/keys.py, synapse/storage/databases/main/keys.py,
synapse/storage/databases/main/pusher.py, synapse/storage/databases/main/pusher.py,
synapse/storage/databases/main/registration.py, synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/session.py,
synapse/storage/databases/main/stream.py, synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py, synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py, synapse/storage/database.py,

View file

@ -118,6 +118,7 @@ from synapse.storage.databases.main.monthly_active_users import (
from synapse.storage.databases.main.presence import PresenceStore from synapse.storage.databases.main.presence import PresenceStore
from synapse.storage.databases.main.room import RoomWorkerStore from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.storage.databases.main.search import SearchStore from synapse.storage.databases.main.search import SearchStore
from synapse.storage.databases.main.session import SessionStore
from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.transactions import TransactionWorkerStore from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
@ -253,6 +254,7 @@ class GenericWorkerSlavedStore(
SearchStore, SearchStore,
TransactionWorkerStore, TransactionWorkerStore,
LockStore, LockStore,
SessionStore,
BaseSlavedStore, BaseSlavedStore,
): ):
pass pass

View file

@ -28,12 +28,11 @@ from synapse.api.constants import (
Membership, Membership,
RoomTypes, RoomTypes,
) )
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.utils import format_event_for_client_v2 from synapse.events.utils import format_event_for_client_v2
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -76,6 +75,9 @@ class _PaginationSession:
class RoomSummaryHandler: class RoomSummaryHandler:
# A unique key used for pagination sessions for the room hierarchy endpoint.
_PAGINATION_SESSION_TYPE = "room_hierarchy_pagination"
# The time a pagination session remains valid for. # The time a pagination session remains valid for.
_PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000 _PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000
@ -87,12 +89,6 @@ class RoomSummaryHandler:
self._server_name = hs.hostname self._server_name = hs.hostname
self._federation_client = hs.get_federation_client() self._federation_client = hs.get_federation_client()
# A map of query information to the current pagination state.
#
# TODO Allow for multiple workers to share this data.
# TODO Expire pagination tokens.
self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {}
# If a user tries to fetch the same page multiple times in quick succession, # If a user tries to fetch the same page multiple times in quick succession,
# only process the first attempt and return its result to subsequent requests. # only process the first attempt and return its result to subsequent requests.
self._pagination_response_cache: ResponseCache[ self._pagination_response_cache: ResponseCache[
@ -102,21 +98,6 @@ class RoomSummaryHandler:
"get_room_hierarchy", "get_room_hierarchy",
) )
def _expire_pagination_sessions(self):
"""Expire pagination session which are old."""
expire_before = (
self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS
)
to_expire = []
for key, value in self._pagination_sessions.items():
if value.creation_time_ms < expire_before:
to_expire.append(key)
for key in to_expire:
logger.debug("Expiring pagination session id %s", key)
del self._pagination_sessions[key]
async def get_space_summary( async def get_space_summary(
self, self,
requester: str, requester: str,
@ -327,18 +308,29 @@ class RoomSummaryHandler:
# If this is continuing a previous session, pull the persisted data. # If this is continuing a previous session, pull the persisted data.
if from_token: if from_token:
self._expire_pagination_sessions() try:
pagination_session = await self._store.get_session(
session_type=self._PAGINATION_SESSION_TYPE,
session_id=from_token,
)
except StoreError:
raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM)
pagination_key = _PaginationKey( # If the requester, room ID, suggested-only, or max depth were modified
requested_room_id, suggested_only, max_depth, from_token # the session is invalid.
) if (
if pagination_key not in self._pagination_sessions: requester != pagination_session["requester"]
or requested_room_id != pagination_session["room_id"]
or suggested_only != pagination_session["suggested_only"]
or max_depth != pagination_session["max_depth"]
):
raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM) raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM)
# Load the previous state. # Load the previous state.
pagination_session = self._pagination_sessions[pagination_key] room_queue = [
room_queue = pagination_session.room_queue _RoomQueueEntry(*fields) for fields in pagination_session["room_queue"]
processed_rooms = pagination_session.processed_rooms ]
processed_rooms = set(pagination_session["processed_rooms"])
else: else:
# The queue of rooms to process, the next room is last on the stack. # The queue of rooms to process, the next room is last on the stack.
room_queue = [_RoomQueueEntry(requested_room_id, ())] room_queue = [_RoomQueueEntry(requested_room_id, ())]
@ -456,13 +448,21 @@ class RoomSummaryHandler:
# If there's additional data, generate a pagination token (and persist state). # If there's additional data, generate a pagination token (and persist state).
if room_queue: if room_queue:
next_batch = random_string(24) result["next_batch"] = await self._store.create_session(
result["next_batch"] = next_batch session_type=self._PAGINATION_SESSION_TYPE,
pagination_key = _PaginationKey( value={
requested_room_id, suggested_only, max_depth, next_batch # Information which must be identical across pagination.
) "requester": requester,
self._pagination_sessions[pagination_key] = _PaginationSession( "room_id": requested_room_id,
self._clock.time_msec(), room_queue, processed_rooms "suggested_only": suggested_only,
"max_depth": max_depth,
# The stored state.
"room_queue": [
attr.astuple(room_entry) for room_entry in room_queue
],
"processed_rooms": list(processed_rooms),
},
expiry_ms=self._PAGINATION_SESSION_VALIDITY_PERIOD_MS,
) )
return result return result

View file

@ -63,6 +63,7 @@ from .relations import RelationsStore
from .room import RoomStore from .room import RoomStore
from .roommember import RoomMemberStore from .roommember import RoomMemberStore
from .search import SearchStore from .search import SearchStore
from .session import SessionStore
from .signatures import SignatureStore from .signatures import SignatureStore
from .state import StateStore from .state import StateStore
from .stats import StatsStore from .stats import StatsStore
@ -121,6 +122,7 @@ class DataStore(
ServerMetricsStore, ServerMetricsStore,
EventForwardExtremitiesStore, EventForwardExtremitiesStore,
LockStore, LockStore,
SessionStore,
): ):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs self.hs = hs

View file

@ -0,0 +1,145 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
import synapse.util.stringutils as stringutils
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
class SessionStore(SQLBaseStore):
"""
A store for generic session data.
Each type of session should provide a unique type (to separate sessions).
Sessions are automatically removed when they expire.
"""
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
# Create a background job for culling expired sessions.
if hs.config.run_background_tasks:
self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000)
async def create_session(
self, session_type: str, value: JsonDict, expiry_ms: int
) -> str:
"""
Creates a new pagination session for the room hierarchy endpoint.
Args:
session_type: The type for this session.
value: The value to store.
expiry_ms: How long before an item is evicted from the cache
in milliseconds. Default is 0, indicating items never get
evicted based on time.
Returns:
The newly created session ID.
Raises:
StoreError if a unique session ID cannot be generated.
"""
# autogen a session ID and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
attempts = 0
while attempts < 5:
session_id = stringutils.random_string(24)
try:
await self.db_pool.simple_insert(
table="sessions",
values={
"session_id": session_id,
"session_type": session_type,
"value": json_encoder.encode(value),
"expiry_time_ms": self.hs.get_clock().time_msec() + expiry_ms,
},
desc="create_session",
)
return session_id
except self.db_pool.engine.module.IntegrityError:
attempts += 1
raise StoreError(500, "Couldn't generate a session ID.")
async def get_session(self, session_type: str, session_id: str) -> JsonDict:
"""
Retrieve data stored with create_session
Args:
session_type: The type for this session.
session_id: The session ID returned from create_session.
Raises:
StoreError if the session cannot be found.
"""
def _get_session(
txn: LoggingTransaction, session_type: str, session_id: str, ts: int
) -> JsonDict:
# This includes the expiry time since items are only periodically
# deleted, not upon expiry.
select_sql = """
SELECT value FROM sessions WHERE
session_type = ? AND session_id = ? AND expiry_time_ms > ?
"""
txn.execute(select_sql, [session_type, session_id, ts])
row = txn.fetchone()
if not row:
raise StoreError(404, "No session")
return db_to_json(row[0])
return await self.db_pool.runInteraction(
"get_session",
_get_session,
session_type,
session_id,
self._clock.time_msec(),
)
@wrap_as_background_process("delete_expired_sessions")
async def _delete_expired_sessions(self) -> None:
"""Remove sessions with expiry dates that have passed."""
def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None:
sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?"
txn.execute(sql, (ts,))
await self.db_pool.runInteraction(
"delete_expired_sessions",
_delete_expired_sessions_txn,
self._clock.time_msec(),
)

View file

@ -0,0 +1,23 @@
/*
* Copyright 2021 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS sessions(
session_type TEXT NOT NULL, -- The unique key for this type of session.
session_id TEXT NOT NULL, -- The session ID passed to the client.
value TEXT NOT NULL, -- A JSON dictionary to persist.
expiry_time_ms BIGINT NOT NULL, -- The time this session will expire (epoch time in milliseconds).
UNIQUE (session_type, session_id)
);