forked from MirrorHub/synapse
Add type hints to some storage classes (#11307)
This commit is contained in:
parent
6ce19b94e8
commit
64ef25391d
9 changed files with 116 additions and 54 deletions
1
changelog.d/11307.misc
Normal file
1
changelog.d/11307.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add type hints to storage classes.
|
7
mypy.ini
7
mypy.ini
|
@ -27,8 +27,6 @@ exclude = (?x)
|
||||||
|synapse/storage/databases/main/__init__.py
|
|synapse/storage/databases/main/__init__.py
|
||||||
|synapse/storage/databases/main/account_data.py
|
|synapse/storage/databases/main/account_data.py
|
||||||
|synapse/storage/databases/main/cache.py
|
|synapse/storage/databases/main/cache.py
|
||||||
|synapse/storage/databases/main/censor_events.py
|
|
||||||
|synapse/storage/databases/main/deviceinbox.py
|
|
||||||
|synapse/storage/databases/main/devices.py
|
|synapse/storage/databases/main/devices.py
|
||||||
|synapse/storage/databases/main/directory.py
|
|synapse/storage/databases/main/directory.py
|
||||||
|synapse/storage/databases/main/e2e_room_keys.py
|
|synapse/storage/databases/main/e2e_room_keys.py
|
||||||
|
@ -38,19 +36,15 @@ exclude = (?x)
|
||||||
|synapse/storage/databases/main/events_bg_updates.py
|
|synapse/storage/databases/main/events_bg_updates.py
|
||||||
|synapse/storage/databases/main/events_forward_extremities.py
|
|synapse/storage/databases/main/events_forward_extremities.py
|
||||||
|synapse/storage/databases/main/events_worker.py
|
|synapse/storage/databases/main/events_worker.py
|
||||||
|synapse/storage/databases/main/filtering.py
|
|
||||||
|synapse/storage/databases/main/group_server.py
|
|synapse/storage/databases/main/group_server.py
|
||||||
|synapse/storage/databases/main/lock.py
|
|
||||||
|synapse/storage/databases/main/media_repository.py
|
|synapse/storage/databases/main/media_repository.py
|
||||||
|synapse/storage/databases/main/metrics.py
|
|synapse/storage/databases/main/metrics.py
|
||||||
|synapse/storage/databases/main/monthly_active_users.py
|
|synapse/storage/databases/main/monthly_active_users.py
|
||||||
|synapse/storage/databases/main/openid.py
|
|
||||||
|synapse/storage/databases/main/presence.py
|
|synapse/storage/databases/main/presence.py
|
||||||
|synapse/storage/databases/main/profile.py
|
|synapse/storage/databases/main/profile.py
|
||||||
|synapse/storage/databases/main/purge_events.py
|
|synapse/storage/databases/main/purge_events.py
|
||||||
|synapse/storage/databases/main/push_rule.py
|
|synapse/storage/databases/main/push_rule.py
|
||||||
|synapse/storage/databases/main/receipts.py
|
|synapse/storage/databases/main/receipts.py
|
||||||
|synapse/storage/databases/main/rejections.py
|
|
||||||
|synapse/storage/databases/main/room.py
|
|synapse/storage/databases/main/room.py
|
||||||
|synapse/storage/databases/main/room_batch.py
|
|synapse/storage/databases/main/room_batch.py
|
||||||
|synapse/storage/databases/main/roommember.py
|
|synapse/storage/databases/main/roommember.py
|
||||||
|
@ -59,7 +53,6 @@ exclude = (?x)
|
||||||
|synapse/storage/databases/main/state.py
|
|synapse/storage/databases/main/state.py
|
||||||
|synapse/storage/databases/main/state_deltas.py
|
|synapse/storage/databases/main/state_deltas.py
|
||||||
|synapse/storage/databases/main/stats.py
|
|synapse/storage/databases/main/stats.py
|
||||||
|synapse/storage/databases/main/tags.py
|
|
||||||
|synapse/storage/databases/main/transactions.py
|
|synapse/storage/databases/main/transactions.py
|
||||||
|synapse/storage/databases/main/user_directory.py
|
|synapse/storage/databases/main/user_directory.py
|
||||||
|synapse/storage/databases/main/user_erasure_store.py
|
|synapse/storage/databases/main/user_erasure_store.py
|
||||||
|
|
|
@ -13,12 +13,12 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from synapse.events.utils import prune_event_dict
|
from synapse.events.utils import prune_event_dict
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
|
@ -41,7 +41,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
||||||
hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
|
hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
|
||||||
|
|
||||||
@wrap_as_background_process("_censor_redactions")
|
@wrap_as_background_process("_censor_redactions")
|
||||||
async def _censor_redactions(self):
|
async def _censor_redactions(self) -> None:
|
||||||
"""Censors all redactions older than the configured period that haven't
|
"""Censors all redactions older than the configured period that haven't
|
||||||
been censored yet.
|
been censored yet.
|
||||||
|
|
||||||
|
@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
||||||
and original_event.internal_metadata.is_redacted()
|
and original_event.internal_metadata.is_redacted()
|
||||||
):
|
):
|
||||||
# Redaction was allowed
|
# Redaction was allowed
|
||||||
pruned_json = json_encoder.encode(
|
pruned_json: Optional[str] = json_encoder.encode(
|
||||||
prune_event_dict(
|
prune_event_dict(
|
||||||
original_event.room_version, original_event.get_dict()
|
original_event.room_version, original_event.get_dict()
|
||||||
)
|
)
|
||||||
|
@ -116,7 +116,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
||||||
|
|
||||||
updates.append((redaction_id, event_id, pruned_json))
|
updates.append((redaction_id, event_id, pruned_json))
|
||||||
|
|
||||||
def _update_censor_txn(txn):
|
def _update_censor_txn(txn: LoggingTransaction) -> None:
|
||||||
for redaction_id, event_id, pruned_json in updates:
|
for redaction_id, event_id, pruned_json in updates:
|
||||||
if pruned_json:
|
if pruned_json:
|
||||||
self._censor_event_txn(txn, event_id, pruned_json)
|
self._censor_event_txn(txn, event_id, pruned_json)
|
||||||
|
@ -130,14 +130,16 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
||||||
|
|
||||||
await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn)
|
await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn)
|
||||||
|
|
||||||
def _censor_event_txn(self, txn, event_id, pruned_json):
|
def _censor_event_txn(
|
||||||
|
self, txn: LoggingTransaction, event_id: str, pruned_json: str
|
||||||
|
) -> None:
|
||||||
"""Censor an event by replacing its JSON in the event_json table with the
|
"""Censor an event by replacing its JSON in the event_json table with the
|
||||||
provided pruned JSON.
|
provided pruned JSON.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn (LoggingTransaction): The database transaction.
|
txn: The database transaction.
|
||||||
event_id (str): The ID of the event to censor.
|
event_id: The ID of the event to censor.
|
||||||
pruned_json (str): The pruned JSON
|
pruned_json: The pruned JSON
|
||||||
"""
|
"""
|
||||||
self.db_pool.simple_update_one_txn(
|
self.db_pool.simple_update_one_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -157,7 +159,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
||||||
# Try to retrieve the event's content from the database or the event cache.
|
# Try to retrieve the event's content from the database or the event cache.
|
||||||
event = await self.get_event(event_id)
|
event = await self.get_event(event_id)
|
||||||
|
|
||||||
def delete_expired_event_txn(txn):
|
def delete_expired_event_txn(txn: LoggingTransaction) -> None:
|
||||||
# Delete the expiry timestamp associated with this event from the database.
|
# Delete the expiry timestamp associated with this event from the database.
|
||||||
self._delete_event_expiry_txn(txn, event_id)
|
self._delete_event_expiry_txn(txn, event_id)
|
||||||
|
|
||||||
|
@ -194,14 +196,14 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
||||||
"delete_expired_event", delete_expired_event_txn
|
"delete_expired_event", delete_expired_event_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def _delete_event_expiry_txn(self, txn, event_id):
|
def _delete_event_expiry_txn(self, txn: LoggingTransaction, event_id: str) -> None:
|
||||||
"""Delete the expiry timestamp associated with an event ID without deleting the
|
"""Delete the expiry timestamp associated with an event ID without deleting the
|
||||||
actual event.
|
actual event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn (LoggingTransaction): The transaction to use to perform the deletion.
|
txn: The transaction to use to perform the deletion.
|
||||||
event_id (str): The event ID to delete the associated expiry timestamp of.
|
event_id: The event ID to delete the associated expiry timestamp of.
|
||||||
"""
|
"""
|
||||||
return self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
|
txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
# Copyright 2016 OpenMarket Ltd
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -19,9 +20,17 @@ from synapse.logging import issue9533_logger
|
||||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||||
from synapse.replication.tcp.streams import ToDeviceStream
|
from synapse.replication.tcp.streams import ToDeviceStream
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
from synapse.storage.database import (
|
||||||
|
DatabasePool,
|
||||||
|
LoggingDatabaseConnection,
|
||||||
|
LoggingTransaction,
|
||||||
|
)
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
from synapse.storage.util.id_generators import (
|
||||||
|
AbstractStreamIdGenerator,
|
||||||
|
MultiWriterIdGenerator,
|
||||||
|
StreamIdGenerator,
|
||||||
|
)
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
|
@ -34,14 +43,21 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DeviceInboxWorkerStore(SQLBaseStore):
|
class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
database: DatabasePool,
|
||||||
|
db_conn: LoggingDatabaseConnection,
|
||||||
|
hs: "HomeServer",
|
||||||
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self._instance_name = hs.get_instance_name()
|
self._instance_name = hs.get_instance_name()
|
||||||
|
|
||||||
# Map of (user_id, device_id) to the last stream_id that has been
|
# Map of (user_id, device_id) to the last stream_id that has been
|
||||||
# deleted up to. This is so that we can no op deletions.
|
# deleted up to. This is so that we can no op deletions.
|
||||||
self._last_device_delete_cache = ExpiringCache(
|
self._last_device_delete_cache: ExpiringCache[
|
||||||
|
Tuple[str, Optional[str]], int
|
||||||
|
] = ExpiringCache(
|
||||||
cache_name="last_device_delete_cache",
|
cache_name="last_device_delete_cache",
|
||||||
clock=self._clock,
|
clock=self._clock,
|
||||||
max_len=10000,
|
max_len=10000,
|
||||||
|
@ -53,14 +69,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
self._instance_name in hs.config.worker.writers.to_device
|
self._instance_name in hs.config.worker.writers.to_device
|
||||||
)
|
)
|
||||||
|
|
||||||
self._device_inbox_id_gen = MultiWriterIdGenerator(
|
self._device_inbox_id_gen: AbstractStreamIdGenerator = (
|
||||||
db_conn=db_conn,
|
MultiWriterIdGenerator(
|
||||||
db=database,
|
db_conn=db_conn,
|
||||||
stream_name="to_device",
|
db=database,
|
||||||
instance_name=self._instance_name,
|
stream_name="to_device",
|
||||||
tables=[("device_inbox", "instance_name", "stream_id")],
|
instance_name=self._instance_name,
|
||||||
sequence_name="device_inbox_sequence",
|
tables=[("device_inbox", "instance_name", "stream_id")],
|
||||||
writers=hs.config.worker.writers.to_device,
|
sequence_name="device_inbox_sequence",
|
||||||
|
writers=hs.config.worker.writers.to_device,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._can_write_to_device = True
|
self._can_write_to_device = True
|
||||||
|
@ -101,6 +119,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||||
if stream_name == ToDeviceStream.NAME:
|
if stream_name == ToDeviceStream.NAME:
|
||||||
|
# If replication is happening than postgres must be being used.
|
||||||
|
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
|
||||||
self._device_inbox_id_gen.advance(instance_name, token)
|
self._device_inbox_id_gen.advance(instance_name, token)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
if row.entity.startswith("@"):
|
if row.entity.startswith("@"):
|
||||||
|
@ -220,11 +240,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
log_kv({"message": f"deleted {count} messages for device", "count": count})
|
log_kv({"message": f"deleted {count} messages for device", "count": count})
|
||||||
|
|
||||||
# Update the cache, ensuring that we only ever increase the value
|
# Update the cache, ensuring that we only ever increase the value
|
||||||
last_deleted_stream_id = self._last_device_delete_cache.get(
|
updated_last_deleted_stream_id = self._last_device_delete_cache.get(
|
||||||
(user_id, device_id), 0
|
(user_id, device_id), 0
|
||||||
)
|
)
|
||||||
self._last_device_delete_cache[(user_id, device_id)] = max(
|
self._last_device_delete_cache[(user_id, device_id)] = max(
|
||||||
last_deleted_stream_id, up_to_stream_id
|
updated_last_deleted_stream_id, up_to_stream_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return count
|
return count
|
||||||
|
@ -432,7 +452,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self._device_inbox_id_gen.get_next() as stream_id:
|
async with self._device_inbox_id_gen.get_next() as stream_id:
|
||||||
now_ms = self.clock.time_msec()
|
now_ms = self._clock.time_msec()
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
|
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
|
||||||
)
|
)
|
||||||
|
@ -483,7 +503,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self._device_inbox_id_gen.get_next() as stream_id:
|
async with self._device_inbox_id_gen.get_next() as stream_id:
|
||||||
now_ms = self.clock.time_msec()
|
now_ms = self._clock.time_msec()
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_messages_from_remote_to_device_inbox",
|
"add_messages_from_remote_to_device_inbox",
|
||||||
add_messages_txn,
|
add_messages_txn,
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -18,6 +19,7 @@ from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
|
from synapse.storage.database import LoggingTransaction
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
|
@ -49,7 +51,7 @@ class FilteringStore(SQLBaseStore):
|
||||||
|
|
||||||
# Need an atomic transaction to SELECT the maximal ID so far then
|
# Need an atomic transaction to SELECT the maximal ID so far then
|
||||||
# INSERT a new one
|
# INSERT a new one
|
||||||
def _do_txn(txn):
|
def _do_txn(txn: LoggingTransaction) -> int:
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT filter_id FROM user_filters "
|
"SELECT filter_id FROM user_filters "
|
||||||
"WHERE user_id = ? AND filter_json = ?"
|
"WHERE user_id = ? AND filter_json = ?"
|
||||||
|
@ -61,7 +63,7 @@ class FilteringStore(SQLBaseStore):
|
||||||
|
|
||||||
sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
|
sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
|
||||||
txn.execute(sql, (user_localpart,))
|
txn.execute(sql, (user_localpart,))
|
||||||
max_id = txn.fetchone()[0]
|
max_id = txn.fetchone()[0] # type: ignore[index]
|
||||||
if max_id is None:
|
if max_id is None:
|
||||||
filter_id = 0
|
filter_id = 0
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
|
from typing import TYPE_CHECKING, Optional, Tuple, Type
|
||||||
from weakref import WeakValueDictionary
|
from weakref import WeakValueDictionary
|
||||||
|
|
||||||
from twisted.internet.interfaces import IReactorCore
|
from twisted.internet.interfaces import IReactorCore
|
||||||
|
@ -62,7 +62,9 @@ class LockStore(SQLBaseStore):
|
||||||
|
|
||||||
# A map from `(lock_name, lock_key)` to the token of any locks that we
|
# A map from `(lock_name, lock_key)` to the token of any locks that we
|
||||||
# think we currently hold.
|
# think we currently hold.
|
||||||
self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary()
|
self._live_tokens: WeakValueDictionary[
|
||||||
|
Tuple[str, str], Lock
|
||||||
|
] = WeakValueDictionary()
|
||||||
|
|
||||||
# When we shut down we want to remove the locks. Technically this can
|
# When we shut down we want to remove the locks. Technically this can
|
||||||
# lead to a race, as we may drop the lock while we are still processing.
|
# lead to a race, as we may drop the lock while we are still processing.
|
||||||
|
|
|
@ -1,6 +1,21 @@
|
||||||
|
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
from synapse.storage.database import LoggingTransaction
|
||||||
|
|
||||||
|
|
||||||
class OpenIdStore(SQLBaseStore):
|
class OpenIdStore(SQLBaseStore):
|
||||||
|
@ -20,7 +35,7 @@ class OpenIdStore(SQLBaseStore):
|
||||||
async def get_user_id_for_open_id_token(
|
async def get_user_id_for_open_id_token(
|
||||||
self, token: str, ts_now_ms: int
|
self, token: str, ts_now_ms: int
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
def get_user_id_for_token_txn(txn):
|
def get_user_id_for_token_txn(txn: LoggingTransaction) -> Optional[str]:
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT user_id FROM open_id_tokens"
|
"SELECT user_id FROM open_id_tokens"
|
||||||
" WHERE token = ? AND ? <= ts_valid_until_ms"
|
" WHERE token = ? AND ? <= ts_valid_until_ms"
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
# Copyright 2018 New Vector Ltd
|
# Copyright 2018 New Vector Ltd
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -14,9 +15,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple, cast
|
||||||
|
|
||||||
from synapse.storage._base import db_to_json
|
from synapse.storage._base import db_to_json
|
||||||
|
from synapse.storage.database import LoggingTransaction
|
||||||
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
|
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
|
@ -50,7 +52,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
|
|
||||||
async def get_all_updated_tags(
|
async def get_all_updated_tags(
|
||||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]:
|
||||||
"""Get updates for tags replication stream.
|
"""Get updates for tags replication stream.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -75,7 +77,9 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
if last_id == current_id:
|
if last_id == current_id:
|
||||||
return [], current_id, False
|
return [], current_id, False
|
||||||
|
|
||||||
def get_all_updated_tags_txn(txn):
|
def get_all_updated_tags_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> List[Tuple[int, str, str]]:
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT stream_id, user_id, room_id"
|
"SELECT stream_id, user_id, room_id"
|
||||||
" FROM room_tags_revisions as r"
|
" FROM room_tags_revisions as r"
|
||||||
|
@ -83,13 +87,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
" ORDER BY stream_id ASC LIMIT ?"
|
" ORDER BY stream_id ASC LIMIT ?"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (last_id, current_id, limit))
|
txn.execute(sql, (last_id, current_id, limit))
|
||||||
return txn.fetchall()
|
# mypy doesn't understand what the query is selecting.
|
||||||
|
return cast(List[Tuple[int, str, str]], txn.fetchall())
|
||||||
|
|
||||||
tag_ids = await self.db_pool.runInteraction(
|
tag_ids = await self.db_pool.runInteraction(
|
||||||
"get_all_updated_tags", get_all_updated_tags_txn
|
"get_all_updated_tags", get_all_updated_tags_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_tag_content(txn, tag_ids):
|
def get_tag_content(
|
||||||
|
txn: LoggingTransaction, tag_ids
|
||||||
|
) -> List[Tuple[int, Tuple[str, str, str]]]:
|
||||||
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
|
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
|
||||||
results = []
|
results = []
|
||||||
for stream_id, user_id, room_id in tag_ids:
|
for stream_id, user_id, room_id in tag_ids:
|
||||||
|
@ -127,15 +134,15 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
given version
|
given version
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The user to get the tags for.
|
user_id: The user to get the tags for.
|
||||||
stream_id(int): The earliest update to get for the user.
|
stream_id: The earliest update to get for the user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A mapping from room_id strings to lists of tag strings for all the
|
A mapping from room_id strings to lists of tag strings for all the
|
||||||
rooms that changed since the stream_id token.
|
rooms that changed since the stream_id token.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_updated_tags_txn(txn):
|
def get_updated_tags_txn(txn: LoggingTransaction) -> List[str]:
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT room_id from room_tags_revisions"
|
"SELECT room_id from room_tags_revisions"
|
||||||
" WHERE user_id = ? AND stream_id > ?"
|
" WHERE user_id = ? AND stream_id > ?"
|
||||||
|
@ -200,7 +207,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
|
|
||||||
content_json = json_encoder.encode(content)
|
content_json = json_encoder.encode(content)
|
||||||
|
|
||||||
def add_tag_txn(txn, next_id):
|
def add_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
|
||||||
self.db_pool.simple_upsert_txn(
|
self.db_pool.simple_upsert_txn(
|
||||||
txn,
|
txn,
|
||||||
table="room_tags",
|
table="room_tags",
|
||||||
|
@ -224,7 +231,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
"""
|
"""
|
||||||
assert self._can_write_to_account_data
|
assert self._can_write_to_account_data
|
||||||
|
|
||||||
def remove_tag_txn(txn, next_id):
|
def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
|
||||||
sql = (
|
sql = (
|
||||||
"DELETE FROM room_tags "
|
"DELETE FROM room_tags "
|
||||||
" WHERE user_id = ? AND room_id = ? AND tag = ?"
|
" WHERE user_id = ? AND room_id = ? AND tag = ?"
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -11,6 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import abc
|
||||||
import heapq
|
import heapq
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
|
@ -87,7 +89,25 @@ def _load_current_id(
|
||||||
return (max if step > 0 else min)(current_id, step)
|
return (max if step > 0 else min)(current_id, step)
|
||||||
|
|
||||||
|
|
||||||
class StreamIdGenerator:
|
class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_next(self) -> AsyncContextManager[int]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_current_token(self) -> int:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_current_token_for_writer(self, instance_name: str) -> int:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class StreamIdGenerator(AbstractStreamIdGenerator):
|
||||||
"""Used to generate new stream ids when persisting events while keeping
|
"""Used to generate new stream ids when persisting events while keeping
|
||||||
track of which transactions have been completed.
|
track of which transactions have been completed.
|
||||||
|
|
||||||
|
@ -209,7 +229,7 @@ class StreamIdGenerator:
|
||||||
return self.get_current_token()
|
return self.get_current_token()
|
||||||
|
|
||||||
|
|
||||||
class MultiWriterIdGenerator:
|
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
"""An ID generator that tracks a stream that can have multiple writers.
|
"""An ID generator that tracks a stream that can have multiple writers.
|
||||||
|
|
||||||
Uses a Postgres sequence to coordinate ID assignment, but positions of other
|
Uses a Postgres sequence to coordinate ID assignment, but positions of other
|
||||||
|
|
Loading…
Reference in a new issue