0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-15 14:32:30 +01:00
synapse/synapse/storage/databases/main/session.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

152 lines
5 KiB
Python
Raw Normal View History

#
2023-11-21 21:29:58 +01:00
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
2023-11-21 21:29:58 +01:00
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
from typing import TYPE_CHECKING
import synapse.util.stringutils as stringutils
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
class SessionStore(SQLBaseStore):
    """
    A store for generic session data.

    Each type of session should provide a unique type (to separate sessions).

    Sessions are automatically removed when they expire.
    """

    # How often the background job that culls expired sessions is run.
    _EXPIRED_SESSION_CULL_INTERVAL_MS = 30 * 60 * 1000

    # How many random session IDs to try before giving up on creating a session.
    _MAX_SESSION_ID_ATTEMPTS = 5

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # Create a background job for culling expired sessions. Only do so on
        # the process configured to run background tasks.
        if hs.config.worker.run_background_tasks:
            self._clock.looping_call(
                self._delete_expired_sessions,
                self._EXPIRED_SESSION_CULL_INTERVAL_MS,
            )

    async def create_session(
        self, session_type: str, value: JsonDict, expiry_ms: int
    ) -> str:
        """
        Create a new session of the given type.

        Args:
            session_type: The type for this session.
            value: The value to store. It is JSON-encoded before being stored.
            expiry_ms: How long, in milliseconds from now, before the session
                expires. An expired session can no longer be retrieved via
                get_session and is removed by the periodic clean-up job. Note
                that a value of 0 (or less) creates a session which is
                already expired.

        Returns:
            The newly created session ID.

        Raises:
            StoreError if a unique session ID cannot be generated.
        """
        # autogen a session ID and try to create it. We may clash, so just
        # try a few times till one goes through, giving up eventually.
        for _ in range(self._MAX_SESSION_ID_ATTEMPTS):
            session_id = stringutils.random_string(24)

            try:
                await self.db_pool.simple_insert(
                    table="sessions",
                    values={
                        "session_id": session_id,
                        "session_type": session_type,
                        "value": json_encoder.encode(value),
                        "expiry_time_ms": self._clock.time_msec() + expiry_ms,
                    },
                    desc="create_session",
                )
            except self.db_pool.engine.module.IntegrityError:
                # The generated ID clashed; try again with a fresh one.
                continue

            return session_id

        raise StoreError(500, "Couldn't generate a session ID.")

    async def get_session(self, session_type: str, session_id: str) -> JsonDict:
        """
        Retrieve data stored with create_session

        Args:
            session_type: The type for this session.
            session_id: The session ID returned from create_session.

        Returns:
            The decoded value which was stored for the session.

        Raises:
            StoreError if the session cannot be found (this includes sessions
            which have expired but have not yet been deleted).
        """

        def _get_session_txn(
            txn: LoggingTransaction, session_type: str, session_id: str, ts: int
        ) -> JsonDict:
            # This includes the expiry time since items are only periodically
            # deleted, not upon expiry.
            select_sql = """
            SELECT value FROM sessions WHERE
            session_type = ? AND session_id = ? AND expiry_time_ms > ?
            """
            txn.execute(select_sql, [session_type, session_id, ts])
            row = txn.fetchone()

            if row is None:
                raise StoreError(404, "No session")

            return db_to_json(row[0])

        return await self.db_pool.runInteraction(
            "get_session",
            _get_session_txn,
            session_type,
            session_id,
            self._clock.time_msec(),
        )

    @wrap_as_background_process("delete_expired_sessions")
    async def _delete_expired_sessions(self) -> None:
        """Remove sessions with expiry dates that have passed."""

        def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None:
            # Uses `<=` so that this is the exact complement of the
            # `expiry_time_ms > ?` check used when fetching a session.
            sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?"
            txn.execute(sql, (ts,))

        await self.db_pool.runInteraction(
            "delete_expired_sessions",
            _delete_expired_sessions_txn,
            self._clock.time_msec(),
        )