From 316590d1ea273115a9e7925236e02d577a231de4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2023 09:58:22 +0000 Subject: [PATCH] Fix bug in `wait_for_stream_position` (#14856) We were incorrectly checking if the *local* token had been advanced, rather than the token for the remote instance. In practice, I don't think this has caused any bugs due to where we use `wait_for_stream_position`, as critically we don't use it on instances that also write to the given streams (and so the local token will lag behind all remote tokens). --- changelog.d/14856.misc | 1 + synapse/replication/tcp/client.py | 2 +- tests/replication/tcp/test_handler.py | 78 +++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14856.misc diff --git a/changelog.d/14856.misc b/changelog.d/14856.misc new file mode 100644 index 000000000..3731d6cbf --- /dev/null +++ b/changelog.d/14856.misc @@ -0,0 +1 @@ +Fix `wait_for_stream_position` to correctly wait for the right instance to advance its token. diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 31022ce5f..322d695bc 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -325,7 +325,7 @@ class ReplicationDataHandler: # anyway in that case we don't need to wait. return - current_position = self._streams[stream_name].current_token(self._instance_name) + current_position = self._streams[stream_name].current_token(instance_name) if position <= current_position: # We're already past the position return diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py index 1e299d2d6..555922409 100644 --- a/tests/replication/tcp/test_handler.py +++ b/tests/replication/tcp/test_handler.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + +from synapse.replication.tcp.commands import PositionCommand, RdataCommand + from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -71,3 +75,77 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual( len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1 ) + + def test_wait_for_stream_position(self) -> None: + """Check that wait for stream position correctly waits for an update from the + correct instance. + """ + store = self.hs.get_datastores().main + cmd_handler = self.hs.get_replication_command_handler() + data_handler = self.hs.get_replication_data_handler() + + worker1 = self.make_worker_hs( + "synapse.app.generic_worker", + extra_config={ + "worker_name": "worker1", + "run_background_tasks_on": "worker1", + "redis": {"enabled": True}, + }, + ) + + cache_id_gen = worker1.get_datastores().main._cache_id_gen + assert cache_id_gen is not None + + self.replicate() + + # First, make sure the master knows that `worker1` exists. + initial_token = cache_id_gen.get_current_token() + cmd_handler.send_command( + PositionCommand("caches", "worker1", initial_token, initial_token) + ) + self.replicate() + + # Next send out a normal RDATA, and check that waiting for that stream + # ID returns immediately. + ctx = cache_id_gen.get_next() + next_token = self.get_success(ctx.__aenter__()) + self.get_success(ctx.__aexit__(None, None, None)) + + cmd_handler.send_command( + RdataCommand("caches", "worker1", next_token, ("func_name", [], 0)) + ) + self.replicate() + + self.get_success( + data_handler.wait_for_stream_position("worker1", "caches", next_token) + ) + + # `wait_for_stream_position` should only return once master receives an + # RDATA from the worker + ctx = cache_id_gen.get_next() + next_token = self.get_success(ctx.__aenter__()) + self.get_success(ctx.__aexit__(None, None, None)) + + d = defer.ensureDeferred( + data_handler.wait_for_stream_position("worker1", "caches", next_token) + ) + self.assertFalse(d.called) + + # ... updating the cache ID gen on the master still shouldn't cause the + # deferred to wake up. + ctx = store._cache_id_gen.get_next() + self.get_success(ctx.__aenter__()) + self.get_success(ctx.__aexit__(None, None, None)) + + d = defer.ensureDeferred( + data_handler.wait_for_stream_position("worker1", "caches", next_token) + ) + self.assertFalse(d.called) + + # ... but receiving the RDATA should + cmd_handler.send_command( + RdataCommand("caches", "worker1", next_token, ("func_name", [], 0)) + ) + self.replicate() + + self.assertTrue(d.called)