Add functions to MultiWriterIdGen used by events stream (#8164)

This commit is contained in:
Erik Johnston 2020-08-25 17:32:30 +01:00 committed by GitHub
parent 5099bd68da
commit eba98fb024
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 145 additions and 3 deletions

1
changelog.d/8164.misc Normal file
View file

@ -0,0 +1 @@
Add functions to `MultiWriterIdGen` used by events stream.

View file

@ -14,9 +14,10 @@
# limitations under the License. # limitations under the License.
import contextlib import contextlib
import heapq
import threading import threading
from collections import deque from collections import deque
from typing import Dict, Set from typing import Dict, List, Set
from typing_extensions import Deque from typing_extensions import Deque
@ -210,6 +211,23 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty). # should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int] self._unfinished_ids = set() # type: Set[int]
# We track the max position where we know everything before has been
# persisted. This is done by a) looking at the min across all instances
# and b) noting that if we have seen a run of persisted positions
# without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
#
# Note: There is no guarentee that the IDs generated by the sequence
# will be gapless; gaps can form when e.g. a transaction was rolled
# back. This means that sometimes we won't be able to skip forward the
# position even though everything has been persisted. However, since
# gaps should be relatively rare it's still worth doing the book keeping
# that allows us to skip forwards when there are gapless runs of
# positions.
self._persisted_upto_position = (
min(self._current_positions.values()) if self._current_positions else 0
)
self._known_persisted_positions = [] # type: List[int]
self._sequence_gen = PostgresSequenceGenerator(sequence_name) self._sequence_gen = PostgresSequenceGenerator(sequence_name)
def _load_current_ids( def _load_current_ids(
@ -234,9 +252,12 @@ class MultiWriterIdGenerator:
return current_positions return current_positions
def _load_next_id_txn(self, txn): def _load_next_id_txn(self, txn) -> int:
return self._sequence_gen.get_next_id_txn(txn) return self._sequence_gen.get_next_id_txn(txn)
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
return self._sequence_gen.get_next_mult_txn(txn, n)
async def get_next(self): async def get_next(self):
""" """
Usage: Usage:
@ -262,6 +283,34 @@ class MultiWriterIdGenerator:
return manager() return manager()
async def get_next_mult(self, n: int):
"""
Usage:
with await stream_id_gen.get_next_mult(5) as stream_ids:
# ... persist events ...
"""
next_ids = await self._db.runInteraction(
"_load_next_mult_id", self._load_next_mult_id_txn, n
)
# Assert the fetched ID is actually greater than any ID we've already
# seen. If not, then the sequence and table have got out of sync
# somehow.
assert max(self.get_positions().values(), default=0) < min(next_ids)
with self._lock:
self._unfinished_ids.update(next_ids)
@contextlib.contextmanager
def manager():
try:
yield next_ids
finally:
for i in next_ids:
self._mark_id_as_finished(i)
return manager()
def get_next_txn(self, txn: LoggingTransaction): def get_next_txn(self, txn: LoggingTransaction):
""" """
Usage: Usage:
@ -326,3 +375,53 @@ class MultiWriterIdGenerator:
self._current_positions[instance_name] = max( self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0) new_id, self._current_positions.get(instance_name, 0)
) )
self._add_persisted_position(new_id)
def get_persisted_upto_position(self) -> int:
"""Get the max position where all previous positions have been
persisted.
Note: In the worst case scenario this will be equal to the minimum
position across writers. This means that the returned position here can
lag if one writer doesn't write very often.
"""
with self._lock:
return self._persisted_upto_position
def _add_persisted_position(self, new_id: int):
"""Record that we have persisted a position.
This is used to keep the `_current_positions` up to date.
"""
# We require that the lock is locked by caller
assert self._lock.locked()
heapq.heappush(self._known_persisted_positions, new_id)
# We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less
# that that have been persisted).
min_curr = min(self._current_positions.values())
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
# We now iterate through the seen positions, discarding those that are
# less than the current min positions, and incrementing the min position
# if its exactly one greater.
#
# This is also where we discard items from `_known_persisted_positions`
# (to ensure the list doesn't infinitely grow).
while self._known_persisted_positions:
if self._known_persisted_positions[0] <= self._persisted_upto_position:
heapq.heappop(self._known_persisted_positions)
elif (
self._known_persisted_positions[0] == self._persisted_upto_position + 1
):
heapq.heappop(self._known_persisted_positions)
self._persisted_upto_position += 1
else:
# There was a gap in seen positions, so there is nothing more to
# do.
break

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import abc import abc
import threading import threading
from typing import Callable, Optional from typing import Callable, List, Optional
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
txn.execute("SELECT nextval(?)", (self._sequence_name,)) txn.execute("SELECT nextval(?)", (self._sequence_name,))
return txn.fetchone()[0] return txn.fetchone()[0]
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
txn.execute(
"SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
)
return [i for (i,) in txn]
GetFirstCallbackType = Callable[[Cursor], int] GetFirstCallbackType = Callable[[Cursor], int]

View file

@ -182,3 +182,39 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_get_persisted_upto_position(self):
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions.
"""
self._insert_rows("first", 3)
self._insert_rows("second", 5)
id_gen = self._create_id_generator("first")
# Min is 3 and there is a gap between 5, so we expect it to be 3.
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
# We advance "first" straight to 6. Min is now 5 but there is no gap so
# we expect it to be 6
id_gen.advance("first", 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
# No gap, so we expect 7.
id_gen.advance("second", 7)
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
# We haven't seen 8 yet, so we expect 7 still.
id_gen.advance("second", 9)
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
# Now that we've seen 7, 8 and 9 we can got straight to 9.
id_gen.advance("first", 8)
self.assertEqual(id_gen.get_persisted_upto_position(), 9)
# Jump forward with gaps. The minimum is 11, even though we haven't seen
# 10 we know that everything before 11 must be persisted.
id_gen.advance("first", 11)
id_gen.advance("second", 15)
self.assertEqual(id_gen.get_persisted_upto_position(), 11)