From 5c35074d859077f5ade846c450d19ea9dceb62f0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 12 Oct 2021 08:55:33 -0400 Subject: [PATCH 1/2] Reset global cache state before cache tests. (#11036) This reverts #11019 and structures the code a bit more like it was before #10985. The global cache state must be reset before running the tests since other test cases might have configured caching (and thus touched the global state). --- changelog.d/11036.misc | 1 + tests/config/test_cache.py | 28 +++++++++++++--------------- 2 files changed, 14 insertions(+), 15 deletions(-) create mode 100644 changelog.d/11036.misc diff --git a/changelog.d/11036.misc b/changelog.d/11036.misc new file mode 100644 index 0000000000..aae5ee62b2 --- /dev/null +++ b/changelog.d/11036.misc @@ -0,0 +1 @@ +Ensure that cache config tests do not share state. diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py index 79d417568d..4bb82e810e 100644 --- a/tests/config/test_cache.py +++ b/tests/config/test_cache.py @@ -12,25 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import patch - from synapse.config.cache import CacheConfig, add_resizable_cache from synapse.util.caches.lrucache import LruCache from tests.unittest import TestCase -# Patch the global _CACHES so that each test runs against its own state. -@patch("synapse.config.cache._CACHES", new_callable=dict) class CacheConfigTests(TestCase): def setUp(self): - # Reset caches before each test + # Reset caches before each test since there's global state involved. self.config = CacheConfig() - - def tearDown(self): self.config.reset() - def test_individual_caches_from_environ(self, _caches): + def tearDown(self): + # Also reset the caches after each test to leave state pristine. + self.config.reset() + + def test_individual_caches_from_environ(self): """ Individual cache factors will be loaded from the environment. """ @@ -43,7 +41,7 @@ class CacheConfigTests(TestCase): self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0}) - def test_config_overrides_environ(self, _caches): + def test_config_overrides_environ(self): """ Individual cache factors defined in the environment will take precedence over those in the config. @@ -60,7 +58,7 @@ class CacheConfigTests(TestCase): {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, ) - def test_individual_instantiated_before_config_load(self, _caches): + def test_individual_instantiated_before_config_load(self): """ If a cache is instantiated before the config is read, it will be given the default cache size in the interim, and then resized once the config @@ -76,7 +74,7 @@ class CacheConfigTests(TestCase): self.assertEqual(cache.max_size, 300) - def test_individual_instantiated_after_config_load(self, _caches): + def test_individual_instantiated_after_config_load(self): """ If a cache is instantiated after the config is read, it will be immediately resized to the correct size given the per_cache_factor if @@ -89,7 +87,7 @@ class CacheConfigTests(TestCase): add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 200) - def test_global_instantiated_before_config_load(self, _caches): + def test_global_instantiated_before_config_load(self): """ If a cache is instantiated before the config is read, it will be given the default cache size in the interim, and then resized to the new @@ -104,7 +102,7 @@ class CacheConfigTests(TestCase): self.assertEqual(cache.max_size, 400) - def test_global_instantiated_after_config_load(self, _caches): + def test_global_instantiated_after_config_load(self): """ If a cache is instantiated after the config is read, it will be immediately resized to the correct size given the global factor if there @@ -117,7 +115,7 @@ class CacheConfigTests(TestCase): add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 150) - def test_cache_with_asterisk_in_name(self, _caches): + def test_cache_with_asterisk_in_name(self): """Some caches have asterisks in their name, test that they are set correctly.""" config = { @@ -143,7 +141,7 @@ class CacheConfigTests(TestCase): add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor) self.assertEqual(cache_c.max_size, 200) - def test_apply_cache_factor_from_config(self, _caches): + def test_apply_cache_factor_from_config(self): """Caches can disable applying cache factor updates, mainly used by event cache size. """ From 333d6f4e843c49002623417e6aa22da0521c4742 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 12 Oct 2021 14:27:09 +0100 Subject: [PATCH 2/2] Fix race in `MultiWriterIdGenerator` (#11045) The race allowed the current position to advance too far when stream IDs are still being persisted. This happened when it received a new stream ID from a remote write between a new stream ID being allocated and it being added to the set of unpersisted stream IDs. Fixes #9424. --- changelog.d/11045.bugfix | 1 + synapse/storage/util/id_generators.py | 82 ++++++++++++++++++++++----- 2 files changed, 68 insertions(+), 15 deletions(-) create mode 100644 changelog.d/11045.bugfix diff --git a/changelog.d/11045.bugfix b/changelog.d/11045.bugfix new file mode 100644 index 0000000000..d712dc946a --- /dev/null +++ b/changelog.d/11045.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when using multiple event persister workers where events were not correctly sent down `/sync` due to a race. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 852bd79fee..670811611f 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -36,7 +36,7 @@ from typing import ( ) import attr -from sortedcontainers import SortedSet +from sortedcontainers import SortedList, SortedSet from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import ( @@ -265,6 +265,15 @@ class MultiWriterIdGenerator: # should be less than the minimum of this set (if not empty). self._unfinished_ids: SortedSet[int] = SortedSet() + # We also need to track when we've requested some new stream IDs but + # they haven't yet been added to the `_unfinished_ids` set. Every time + # we request a new stream ID we add the current max stream ID to the + # list, and remove it once we've added the newly allocated IDs to the + # `_unfinished_ids` set. This means that we *may* be allocated stream + # IDs above those in the list, and so we can't advance the local current + # position beyond the minimum stream ID in this list. + self._in_flight_fetches: SortedList[int] = SortedList() + # Set of local IDs that we've processed that are larger than the current # position, due to there being smaller unpersisted IDs. self._finished_ids: Set[int] = set() @@ -290,6 +299,9 @@ class MultiWriterIdGenerator: ) self._known_persisted_positions: List[int] = [] + # The maximum stream ID that we have seen been allocated across any writer. + self._max_seen_allocated_stream_id = 1 + self._sequence_gen = PostgresSequenceGenerator(sequence_name) # We check that the table and sequence haven't diverged. @@ -305,6 +317,10 @@ class MultiWriterIdGenerator: # This goes and fills out the above state from the database. self._load_current_ids(db_conn, tables) + self._max_seen_allocated_stream_id = max( + self._current_positions.values(), default=1 + ) + def _load_current_ids( self, db_conn: LoggingDatabaseConnection, @@ -411,10 +427,32 @@ class MultiWriterIdGenerator: cur.close() def _load_next_id_txn(self, txn: Cursor) -> int: - return self._sequence_gen.get_next_id_txn(txn) + stream_ids = self._load_next_mult_id_txn(txn, 1) + return stream_ids[0] def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]: - return self._sequence_gen.get_next_mult_txn(txn, n) + # We need to track that we've requested some more stream IDs, and what + # the current max allocated stream ID is. This is to prevent a race + # where we've been allocated stream IDs but they have not yet been added + # to the `_unfinished_ids` set, allowing the current position to advance + # past them. + with self._lock: + current_max = self._max_seen_allocated_stream_id + self._in_flight_fetches.add(current_max) + + try: + stream_ids = self._sequence_gen.get_next_mult_txn(txn, n) + + with self._lock: + self._unfinished_ids.update(stream_ids) + self._max_seen_allocated_stream_id = max( + self._max_seen_allocated_stream_id, self._unfinished_ids[-1] + ) + finally: + with self._lock: + self._in_flight_fetches.remove(current_max) + + return stream_ids def get_next(self) -> AsyncContextManager[int]: """ @@ -463,9 +501,6 @@ class MultiWriterIdGenerator: next_id = self._load_next_id_txn(txn) - with self._lock: - self._unfinished_ids.add(next_id) - txn.call_after(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id) @@ -497,15 +532,27 @@ class MultiWriterIdGenerator: new_cur: Optional[int] = None - if self._unfinished_ids: + if self._unfinished_ids or self._in_flight_fetches: # If there are unfinished IDs then the new position will be the - # largest finished ID less than the minimum unfinished ID. + # largest finished ID strictly less than the minimum unfinished + # ID. + + # The minimum unfinished ID needs to take account of both + # `_unfinished_ids` and `_in_flight_fetches`. + if self._unfinished_ids and self._in_flight_fetches: + # `_in_flight_fetches` stores the maximum safe stream ID, so + # we add one to make it equivalent to the minimum unsafe ID. + min_unfinished = min( + self._unfinished_ids[0], self._in_flight_fetches[0] + 1 + ) + elif self._in_flight_fetches: + min_unfinished = self._in_flight_fetches[0] + 1 + else: + min_unfinished = self._unfinished_ids[0] finished = set() - - min_unfinshed = self._unfinished_ids[0] for s in self._finished_ids: - if s < min_unfinshed: + if s < min_unfinished: if new_cur is None or new_cur < s: new_cur = s else: @@ -575,6 +622,10 @@ class MultiWriterIdGenerator: new_id, self._current_positions.get(instance_name, 0) ) + self._max_seen_allocated_stream_id = max( + self._max_seen_allocated_stream_id, new_id + ) + self._add_persisted_position(new_id) def get_persisted_upto_position(self) -> int: @@ -605,7 +656,11 @@ class MultiWriterIdGenerator: # to report a recent position when asked, rather than a potentially old # one (if this instance hasn't written anything for a while). our_current_position = self._current_positions.get(self._instance_name) - if our_current_position and not self._unfinished_ids: + if ( + our_current_position + and not self._unfinished_ids + and not self._in_flight_fetches + ): self._current_positions[self._instance_name] = max( our_current_position, new_id ) @@ -697,9 +752,6 @@ class _MultiWriterCtxManager: db_autocommit=True, ) - with self.id_gen._lock: - self.id_gen._unfinished_ids.update(self.stream_ids) - if self.multiple_ids is None: return self.stream_ids[0] * self.id_gen._return_factor else: