Add type hints to some storage classes (#11307)

This commit is contained in:
Patrick Cloke 2021-11-11 08:47:31 -05:00 committed by GitHub
parent 6ce19b94e8
commit 64ef25391d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 116 additions and 54 deletions

1
changelog.d/11307.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to storage classes.

View file

@ -27,8 +27,6 @@ exclude = (?x)
|synapse/storage/databases/main/__init__.py |synapse/storage/databases/main/__init__.py
|synapse/storage/databases/main/account_data.py |synapse/storage/databases/main/account_data.py
|synapse/storage/databases/main/cache.py |synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/censor_events.py
|synapse/storage/databases/main/deviceinbox.py
|synapse/storage/databases/main/devices.py |synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/directory.py |synapse/storage/databases/main/directory.py
|synapse/storage/databases/main/e2e_room_keys.py |synapse/storage/databases/main/e2e_room_keys.py
@ -38,19 +36,15 @@ exclude = (?x)
|synapse/storage/databases/main/events_bg_updates.py |synapse/storage/databases/main/events_bg_updates.py
|synapse/storage/databases/main/events_forward_extremities.py |synapse/storage/databases/main/events_forward_extremities.py
|synapse/storage/databases/main/events_worker.py |synapse/storage/databases/main/events_worker.py
|synapse/storage/databases/main/filtering.py
|synapse/storage/databases/main/group_server.py |synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/lock.py
|synapse/storage/databases/main/media_repository.py |synapse/storage/databases/main/media_repository.py
|synapse/storage/databases/main/metrics.py |synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py |synapse/storage/databases/main/monthly_active_users.py
|synapse/storage/databases/main/openid.py
|synapse/storage/databases/main/presence.py |synapse/storage/databases/main/presence.py
|synapse/storage/databases/main/profile.py |synapse/storage/databases/main/profile.py
|synapse/storage/databases/main/purge_events.py |synapse/storage/databases/main/purge_events.py
|synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py |synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/rejections.py
|synapse/storage/databases/main/room.py |synapse/storage/databases/main/room.py
|synapse/storage/databases/main/room_batch.py |synapse/storage/databases/main/room_batch.py
|synapse/storage/databases/main/roommember.py |synapse/storage/databases/main/roommember.py
@ -59,7 +53,6 @@ exclude = (?x)
|synapse/storage/databases/main/state.py |synapse/storage/databases/main/state.py
|synapse/storage/databases/main/state_deltas.py |synapse/storage/databases/main/state_deltas.py
|synapse/storage/databases/main/stats.py |synapse/storage/databases/main/stats.py
|synapse/storage/databases/main/tags.py
|synapse/storage/databases/main/transactions.py |synapse/storage/databases/main/transactions.py
|synapse/storage/databases/main/user_directory.py |synapse/storage/databases/main/user_directory.py
|synapse/storage/databases/main/user_erasure_store.py |synapse/storage/databases/main/user_erasure_store.py

View file

@ -13,12 +13,12 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
from synapse.events.utils import prune_event_dict from synapse.events.utils import prune_event_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util import json_encoder from synapse.util import json_encoder
@ -41,7 +41,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000) hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
@wrap_as_background_process("_censor_redactions") @wrap_as_background_process("_censor_redactions")
async def _censor_redactions(self): async def _censor_redactions(self) -> None:
"""Censors all redactions older than the configured period that haven't """Censors all redactions older than the configured period that haven't
been censored yet. been censored yet.
@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
and original_event.internal_metadata.is_redacted() and original_event.internal_metadata.is_redacted()
): ):
# Redaction was allowed # Redaction was allowed
pruned_json = json_encoder.encode( pruned_json: Optional[str] = json_encoder.encode(
prune_event_dict( prune_event_dict(
original_event.room_version, original_event.get_dict() original_event.room_version, original_event.get_dict()
) )
@ -116,7 +116,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
updates.append((redaction_id, event_id, pruned_json)) updates.append((redaction_id, event_id, pruned_json))
def _update_censor_txn(txn): def _update_censor_txn(txn: LoggingTransaction) -> None:
for redaction_id, event_id, pruned_json in updates: for redaction_id, event_id, pruned_json in updates:
if pruned_json: if pruned_json:
self._censor_event_txn(txn, event_id, pruned_json) self._censor_event_txn(txn, event_id, pruned_json)
@ -130,14 +130,16 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn) await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn)
def _censor_event_txn(self, txn, event_id, pruned_json): def _censor_event_txn(
self, txn: LoggingTransaction, event_id: str, pruned_json: str
) -> None:
"""Censor an event by replacing its JSON in the event_json table with the """Censor an event by replacing its JSON in the event_json table with the
provided pruned JSON. provided pruned JSON.
Args: Args:
txn (LoggingTransaction): The database transaction. txn: The database transaction.
event_id (str): The ID of the event to censor. event_id: The ID of the event to censor.
pruned_json (str): The pruned JSON pruned_json: The pruned JSON
""" """
self.db_pool.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, txn,
@ -157,7 +159,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
# Try to retrieve the event's content from the database or the event cache. # Try to retrieve the event's content from the database or the event cache.
event = await self.get_event(event_id) event = await self.get_event(event_id)
def delete_expired_event_txn(txn): def delete_expired_event_txn(txn: LoggingTransaction) -> None:
# Delete the expiry timestamp associated with this event from the database. # Delete the expiry timestamp associated with this event from the database.
self._delete_event_expiry_txn(txn, event_id) self._delete_event_expiry_txn(txn, event_id)
@ -194,14 +196,14 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
"delete_expired_event", delete_expired_event_txn "delete_expired_event", delete_expired_event_txn
) )
def _delete_event_expiry_txn(self, txn, event_id): def _delete_event_expiry_txn(self, txn: LoggingTransaction, event_id: str) -> None:
"""Delete the expiry timestamp associated with an event ID without deleting the """Delete the expiry timestamp associated with an event ID without deleting the
actual event. actual event.
Args: Args:
txn (LoggingTransaction): The transaction to use to perform the deletion. txn: The transaction to use to perform the deletion.
event_id (str): The event ID to delete the associated expiry timestamp of. event_id: The event ID to delete the associated expiry timestamp of.
""" """
return self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn=txn, table="event_expiry", keyvalues={"event_id": event_id} txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
) )

View file

@ -1,4 +1,5 @@
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -19,9 +20,17 @@ from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.replication.tcp.streams import ToDeviceStream from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -34,14 +43,21 @@ logger = logging.getLogger(__name__)
class DeviceInboxWorkerStore(SQLBaseStore): class DeviceInboxWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
# Map of (user_id, device_id) to the last stream_id that has been # Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions. # deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache( self._last_device_delete_cache: ExpiringCache[
Tuple[str, Optional[str]], int
] = ExpiringCache(
cache_name="last_device_delete_cache", cache_name="last_device_delete_cache",
clock=self._clock, clock=self._clock,
max_len=10000, max_len=10000,
@ -53,14 +69,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
self._instance_name in hs.config.worker.writers.to_device self._instance_name in hs.config.worker.writers.to_device
) )
self._device_inbox_id_gen = MultiWriterIdGenerator( self._device_inbox_id_gen: AbstractStreamIdGenerator = (
db_conn=db_conn, MultiWriterIdGenerator(
db=database, db_conn=db_conn,
stream_name="to_device", db=database,
instance_name=self._instance_name, stream_name="to_device",
tables=[("device_inbox", "instance_name", "stream_id")], instance_name=self._instance_name,
sequence_name="device_inbox_sequence", tables=[("device_inbox", "instance_name", "stream_id")],
writers=hs.config.worker.writers.to_device, sequence_name="device_inbox_sequence",
writers=hs.config.worker.writers.to_device,
)
) )
else: else:
self._can_write_to_device = True self._can_write_to_device = True
@ -101,6 +119,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def process_replication_rows(self, stream_name, instance_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ToDeviceStream.NAME: if stream_name == ToDeviceStream.NAME:
# If replication is happening than postgres must be being used.
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
self._device_inbox_id_gen.advance(instance_name, token) self._device_inbox_id_gen.advance(instance_name, token)
for row in rows: for row in rows:
if row.entity.startswith("@"): if row.entity.startswith("@"):
@ -220,11 +240,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
log_kv({"message": f"deleted {count} messages for device", "count": count}) log_kv({"message": f"deleted {count} messages for device", "count": count})
# Update the cache, ensuring that we only ever increase the value # Update the cache, ensuring that we only ever increase the value
last_deleted_stream_id = self._last_device_delete_cache.get( updated_last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0 (user_id, device_id), 0
) )
self._last_device_delete_cache[(user_id, device_id)] = max( self._last_device_delete_cache[(user_id, device_id)] = max(
last_deleted_stream_id, up_to_stream_id updated_last_deleted_stream_id, up_to_stream_id
) )
return count return count
@ -432,7 +452,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
) )
async with self._device_inbox_id_gen.get_next() as stream_id: async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec() now_ms = self._clock.time_msec()
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
) )
@ -483,7 +503,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
) )
async with self._device_inbox_id_gen.get_next() as stream_id: async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec() now_ms = self._clock.time_msec()
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox", "add_messages_from_remote_to_device_inbox",
add_messages_txn, add_messages_txn,

View file

@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -18,6 +19,7 @@ from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -49,7 +51,7 @@ class FilteringStore(SQLBaseStore):
# Need an atomic transaction to SELECT the maximal ID so far then # Need an atomic transaction to SELECT the maximal ID so far then
# INSERT a new one # INSERT a new one
def _do_txn(txn): def _do_txn(txn: LoggingTransaction) -> int:
sql = ( sql = (
"SELECT filter_id FROM user_filters " "SELECT filter_id FROM user_filters "
"WHERE user_id = ? AND filter_json = ?" "WHERE user_id = ? AND filter_json = ?"
@ -61,7 +63,7 @@ class FilteringStore(SQLBaseStore):
sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?" sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
txn.execute(sql, (user_localpart,)) txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0] max_id = txn.fetchone()[0] # type: ignore[index]
if max_id is None: if max_id is None:
filter_id = 0 filter_id = 0
else: else:

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from types import TracebackType from types import TracebackType
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type from typing import TYPE_CHECKING, Optional, Tuple, Type
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
from twisted.internet.interfaces import IReactorCore from twisted.internet.interfaces import IReactorCore
@ -62,7 +62,9 @@ class LockStore(SQLBaseStore):
# A map from `(lock_name, lock_key)` to the token of any locks that we # A map from `(lock_name, lock_key)` to the token of any locks that we
# think we currently hold. # think we currently hold.
self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary() self._live_tokens: WeakValueDictionary[
Tuple[str, str], Lock
] = WeakValueDictionary()
# When we shut down we want to remove the locks. Technically this can # When we shut down we want to remove the locks. Technically this can
# lead to a race, as we may drop the lock while we are still processing. # lead to a race, as we may drop the lock while we are still processing.

View file

@ -1,6 +1,21 @@
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional from typing import Optional
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
class OpenIdStore(SQLBaseStore): class OpenIdStore(SQLBaseStore):
@ -20,7 +35,7 @@ class OpenIdStore(SQLBaseStore):
async def get_user_id_for_open_id_token( async def get_user_id_for_open_id_token(
self, token: str, ts_now_ms: int self, token: str, ts_now_ms: int
) -> Optional[str]: ) -> Optional[str]:
def get_user_id_for_token_txn(txn): def get_user_id_for_token_txn(txn: LoggingTransaction) -> Optional[str]:
sql = ( sql = (
"SELECT user_id FROM open_id_tokens" "SELECT user_id FROM open_id_tokens"
" WHERE token = ? AND ? <= ts_valid_until_ms" " WHERE token = ? AND ? <= ts_valid_until_ms"

View file

@ -1,5 +1,6 @@
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd # Copyright 2018 New Vector Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,9 +15,10 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, List, Tuple from typing import Dict, List, Tuple, cast
from synapse.storage._base import db_to_json from synapse.storage._base import db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
@ -50,7 +52,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_all_updated_tags( async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]: ) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]:
"""Get updates for tags replication stream. """Get updates for tags replication stream.
Args: Args:
@ -75,7 +77,9 @@ class TagsWorkerStore(AccountDataWorkerStore):
if last_id == current_id: if last_id == current_id:
return [], current_id, False return [], current_id, False
def get_all_updated_tags_txn(txn): def get_all_updated_tags_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str]]:
sql = ( sql = (
"SELECT stream_id, user_id, room_id" "SELECT stream_id, user_id, room_id"
" FROM room_tags_revisions as r" " FROM room_tags_revisions as r"
@ -83,13 +87,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
" ORDER BY stream_id ASC LIMIT ?" " ORDER BY stream_id ASC LIMIT ?"
) )
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall() # mypy doesn't understand what the query is selecting.
return cast(List[Tuple[int, str, str]], txn.fetchall())
tag_ids = await self.db_pool.runInteraction( tag_ids = await self.db_pool.runInteraction(
"get_all_updated_tags", get_all_updated_tags_txn "get_all_updated_tags", get_all_updated_tags_txn
) )
def get_tag_content(txn, tag_ids): def get_tag_content(
txn: LoggingTransaction, tag_ids
) -> List[Tuple[int, Tuple[str, str, str]]]:
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?" sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
results = [] results = []
for stream_id, user_id, room_id in tag_ids: for stream_id, user_id, room_id in tag_ids:
@ -127,15 +134,15 @@ class TagsWorkerStore(AccountDataWorkerStore):
given version given version
Args: Args:
user_id(str): The user to get the tags for. user_id: The user to get the tags for.
stream_id(int): The earliest update to get for the user. stream_id: The earliest update to get for the user.
Returns: Returns:
A mapping from room_id strings to lists of tag strings for all the A mapping from room_id strings to lists of tag strings for all the
rooms that changed since the stream_id token. rooms that changed since the stream_id token.
""" """
def get_updated_tags_txn(txn): def get_updated_tags_txn(txn: LoggingTransaction) -> List[str]:
sql = ( sql = (
"SELECT room_id from room_tags_revisions" "SELECT room_id from room_tags_revisions"
" WHERE user_id = ? AND stream_id > ?" " WHERE user_id = ? AND stream_id > ?"
@ -200,7 +207,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
content_json = json_encoder.encode(content) content_json = json_encoder.encode(content)
def add_tag_txn(txn, next_id): def add_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
self.db_pool.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
table="room_tags", table="room_tags",
@ -224,7 +231,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
""" """
assert self._can_write_to_account_data assert self._can_write_to_account_data
def remove_tag_txn(txn, next_id): def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
sql = ( sql = (
"DELETE FROM room_tags " "DELETE FROM room_tags "
" WHERE user_id = ? AND room_id = ? AND tag = ?" " WHERE user_id = ? AND room_id = ? AND tag = ?"

View file

@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -11,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import heapq import heapq
import logging import logging
import threading import threading
@ -87,7 +89,25 @@ def _load_current_id(
return (max if step > 0 else min)(current_id, step) return (max if step > 0 else min)(current_id, step)
class StreamIdGenerator: class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
@abc.abstractmethod
def get_next(self) -> AsyncContextManager[int]:
raise NotImplementedError()
@abc.abstractmethod
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
raise NotImplementedError()
@abc.abstractmethod
def get_current_token(self) -> int:
raise NotImplementedError()
@abc.abstractmethod
def get_current_token_for_writer(self, instance_name: str) -> int:
raise NotImplementedError()
class StreamIdGenerator(AbstractStreamIdGenerator):
"""Used to generate new stream ids when persisting events while keeping """Used to generate new stream ids when persisting events while keeping
track of which transactions have been completed. track of which transactions have been completed.
@ -209,7 +229,7 @@ class StreamIdGenerator:
return self.get_current_token() return self.get_current_token()
class MultiWriterIdGenerator: class MultiWriterIdGenerator(AbstractStreamIdGenerator):
"""An ID generator that tracks a stream that can have multiple writers. """An ID generator that tracks a stream that can have multiple writers.
Uses a Postgres sequence to coordinate ID assignment, but positions of other Uses a Postgres sequence to coordinate ID assignment, but positions of other