0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-18 14:54:03 +01:00

Add a get_next_txn method to StreamIdGenerator to match MultiWriterIdGenerator (#15191

This commit is contained in:
Andrew Morgan 2023-03-02 18:27:00 +00:00 committed by GitHub
parent ecbe0ddbe7
commit 1eea662780
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 48 additions and 11 deletions

1
changelog.d/15191.misc Normal file
View file

@ -0,0 +1 @@
Add a `get_next_txn` method to `StreamIdGenerator` to match `MultiWriterIdGenerator`.

View file

@ -40,7 +40,6 @@ from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import ( from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator, AbstractStreamIdGenerator,
AbstractStreamIdTracker,
MultiWriterIdGenerator, MultiWriterIdGenerator,
StreamIdGenerator, StreamIdGenerator,
) )
@ -64,14 +63,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# `_can_write_to_account_data` indicates whether the current worker is allowed
# to write account data. A value of `True` implies that `_account_data_id_gen`
# is an `AbstractStreamIdGenerator` and not just a tracker.
self._account_data_id_gen: AbstractStreamIdTracker
self._can_write_to_account_data = ( self._can_write_to_account_data = (
self._instance_name in hs.config.worker.writers.account_data self._instance_name in hs.config.worker.writers.account_data
) )
self._account_data_id_gen: AbstractStreamIdGenerator
if isinstance(database.engine, PostgresEngine): if isinstance(database.engine, PostgresEngine):
self._account_data_id_gen = MultiWriterIdGenerator( self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn, db_conn=db_conn,
@ -558,7 +555,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
The maximum stream ID. The maximum stream ID.
""" """
assert self._can_write_to_account_data assert self._can_write_to_account_data
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
content_json = json_encoder.encode(content) content_json = json_encoder.encode(content)
@ -598,7 +594,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
data to delete. data to delete.
""" """
assert self._can_write_to_account_data assert self._can_write_to_account_data
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
def _remove_account_data_for_room_txn( def _remove_account_data_for_room_txn(
txn: LoggingTransaction, next_id: int txn: LoggingTransaction, next_id: int
@ -663,7 +658,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
The maximum stream ID. The maximum stream ID.
""" """
assert self._can_write_to_account_data assert self._can_write_to_account_data
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
async with self._account_data_id_gen.get_next() as next_id: async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
@ -770,7 +764,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
to delete. to delete.
""" """
assert self._can_write_to_account_data assert self._can_write_to_account_data
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
def _remove_account_data_for_user_txn( def _remove_account_data_for_user_txn(
txn: LoggingTransaction, next_id: int txn: LoggingTransaction, next_id: int

View file

@ -158,6 +158,15 @@ class AbstractStreamIdGenerator(AbstractStreamIdTracker):
""" """
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod
def get_next_txn(self, txn: LoggingTransaction) -> int:
"""
Usage:
stream_id_gen.get_next_txn(txn)
# ... persist events ...
"""
raise NotImplementedError()
class StreamIdGenerator(AbstractStreamIdGenerator): class StreamIdGenerator(AbstractStreamIdGenerator):
"""Generates and tracks stream IDs for a stream with a single writer. """Generates and tracks stream IDs for a stream with a single writer.
@ -263,6 +272,40 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager()) return _AsyncCtxManagerWrapper(manager())
def get_next_txn(self, txn: LoggingTransaction) -> int:
"""
Retrieve the next stream ID from within a database transaction.
Clean-up functions will be called when the transaction finishes.
Args:
txn: The database transaction object.
Returns:
The next stream ID.
"""
if not self._is_writer:
raise Exception("Tried to allocate stream ID on non-writer")
# Get the next stream ID.
with self._lock:
self._current += self._step
next_id = self._current
self._unfinished_ids[next_id] = next_id
def clear_unfinished_id(id_to_clear: int) -> None:
"""A function to mark processing this ID as finished"""
with self._lock:
self._unfinished_ids.pop(id_to_clear)
# Mark this ID as finished once the database transaction itself finishes.
txn.call_after(clear_unfinished_id, next_id)
txn.call_on_exception(clear_unfinished_id, next_id)
# Return the new ID.
return next_id
def get_current_token(self) -> int: def get_current_token(self) -> int:
if not self._is_writer: if not self._is_writer:
return self._current return self._current
@ -568,7 +611,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
""" """
Usage: Usage:
stream_id = stream_id_gen.get_next(txn) stream_id = stream_id_gen.get_next_txn(txn)
# ... persist event ... # ... persist event ...
""" """

View file

@ -205,7 +205,7 @@ class LocalSequenceGenerator(SequenceGenerator):
""" """
Args: Args:
get_first_callback: a callback which is called on the first call to get_first_callback: a callback which is called on the first call to
get_next_id_txn; should return the curreent maximum id get_next_id_txn; should return the current maximum id
""" """
# the callback. this is cleared after it is called, so that it can be GCed. # the callback. this is cleared after it is called, so that it can be GCed.
self._callback: Optional[GetFirstCallbackType] = get_first_callback self._callback: Optional[GetFirstCallbackType] = get_first_callback