From 4f21c33be301b8ea6369039c3ad8baa51878e4d5 Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Mon, 30 Mar 2020 16:37:24 +0100
Subject: [PATCH] Remove usage of "conn_id" for presence. (#7128)

* Remove `conn_id` usage for UserSyncCommand.

Each tcp replication connection is assigned a "conn_id", which is used
to give an ID to a remotely connected worker. In a redis world, there
will no longer be a one to one mapping between connection and instance,
so instead we need to replace such usages with an ID generated by the
remote instances and included in the replicaiton commands.

This really only effects UserSyncCommand.

* Add CLEAR_USER_SYNCS command that is sent on shutdown.

This should help with the case where a synchrotron gets restarted
gracefully, rather than rely on 5 minute timeout.
---
 changelog.d/7128.misc               |  1 +
 docs/tcp_replication.md             |  6 +++++
 synapse/app/generic_worker.py       | 20 ++++++++++++----
 synapse/replication/tcp/client.py   |  6 +++--
 synapse/replication/tcp/commands.py | 36 +++++++++++++++++++++++++----
 synapse/replication/tcp/protocol.py |  9 ++++++--
 synapse/replication/tcp/resource.py | 17 ++++++--------
 synapse/server.py                   | 11 +++++++++
 synapse/server.pyi                  |  2 ++
 9 files changed, 86 insertions(+), 22 deletions(-)
 create mode 100644 changelog.d/7128.misc

diff --git a/changelog.d/7128.misc b/changelog.d/7128.misc
new file mode 100644
index 000000000..5703f6d2e
--- /dev/null
+++ b/changelog.d/7128.misc
@@ -0,0 +1 @@
+Add explicit `instance_id` for USER_SYNC commands and remove implicit `conn_id` usage.
diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index d4f7d9ec1..3be8e50c4 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -198,6 +198,12 @@ Asks the server for the current position of all streams.
 
    A user has started or stopped syncing
 
+#### CLEAR_USER_SYNC (C)
+
+   The server should clear all associated user sync data from the worker.
+
+   This is used when a worker is shutting down.
+
 #### FEDERATION_ACK (C)
 
    Acknowledge receipt of some federation data
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index fba7ad955..1ee266f7c 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -65,6 +65,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.slave.storage.transactions import SlavedTransactionStore
 from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.replication.tcp.commands import ClearUserSyncsCommand
 from synapse.replication.tcp.streams import (
     AccountDataStream,
     DeviceListsStream,
@@ -124,7 +125,6 @@ from synapse.types import ReadReceipt
 from synapse.util.async_helpers import Linearizer
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
-from synapse.util.stringutils import random_string
 from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger("synapse.app.generic_worker")
@@ -233,6 +233,7 @@ class GenericWorkerPresence(object):
         self.user_to_num_current_syncs = {}
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
+        self.instance_id = hs.get_instance_id()
 
         active_presence = self.store.take_presence_startup_info()
         self.user_to_current_state = {state.user_id: state for state in active_presence}
@@ -245,13 +246,24 @@ class GenericWorkerPresence(object):
             self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
         )
 
-        self.process_id = random_string(16)
-        logger.info("Presence process_id is %r", self.process_id)
+        hs.get_reactor().addSystemEventTrigger(
+            "before",
+            "shutdown",
+            run_as_background_process,
+            "generic_presence.on_shutdown",
+            self._on_shutdown,
+        )
+
+    def _on_shutdown(self):
+        if self.hs.config.use_presence:
+            self.hs.get_tcp_replication().send_command(
+                ClearUserSyncsCommand(self.instance_id)
+            )
 
     def send_user_sync(self, user_id, is_syncing, last_sync_ms):
         if self.hs.config.use_presence:
             self.hs.get_tcp_replication().send_user_sync(
-                user_id, is_syncing, last_sync_ms
+                self.instance_id, user_id, is_syncing, last_sync_ms
             )
 
     def mark_as_coming_online(self, user_id):
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 7e7ad0f79..e86d9805f 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -189,10 +189,12 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
         """
         self.send_command(FederationAckCommand(token))
 
-    def send_user_sync(self, user_id, is_syncing, last_sync_ms):
+    def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
         """Poke the master that a user has started/stopped syncing.
         """
-        self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
+        self.send_command(
+            UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+        )
 
     def send_remove_pusher(self, app_id, push_key, user_id):
         """Poke the master to remove a pusher for a user
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 5a6b73409..e4eec643f 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -207,30 +207,32 @@ class UserSyncCommand(Command):
 
     Format::
 
-        USER_SYNC <user_id> <state> <last_sync_ms>
+        USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
 
     Where <state> is either "start" or "stop"
     """
 
     NAME = "USER_SYNC"
 
-    def __init__(self, user_id, is_syncing, last_sync_ms):
+    def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
+        self.instance_id = instance_id
         self.user_id = user_id
         self.is_syncing = is_syncing
         self.last_sync_ms = last_sync_ms
 
     @classmethod
     def from_line(cls, line):
-        user_id, state, last_sync_ms = line.split(" ", 2)
+        instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
 
         if state not in ("start", "end"):
             raise Exception("Invalid USER_SYNC state %r" % (state,))
 
-        return cls(user_id, state == "start", int(last_sync_ms))
+        return cls(instance_id, user_id, state == "start", int(last_sync_ms))
 
     def to_line(self):
         return " ".join(
             (
+                self.instance_id,
                 self.user_id,
                 "start" if self.is_syncing else "end",
                 str(self.last_sync_ms),
@@ -238,6 +240,30 @@ class UserSyncCommand(Command):
         )
 
 
+class ClearUserSyncsCommand(Command):
+    """Sent by the client to inform the server that it should drop all
+    information about syncing users sent by the client.
+
+    Mainly used when client is about to shut down.
+
+    Format::
+
+        CLEAR_USER_SYNC <instance_id>
+    """
+
+    NAME = "CLEAR_USER_SYNC"
+
+    def __init__(self, instance_id):
+        self.instance_id = instance_id
+
+    @classmethod
+    def from_line(cls, line):
+        return cls(line)
+
+    def to_line(self):
+        return self.instance_id
+
+
 class FederationAckCommand(Command):
     """Sent by the client when it has processed up to a given point in the
     federation stream. This allows the master to drop in-memory caches of the
@@ -398,6 +424,7 @@ _COMMANDS = (
     InvalidateCacheCommand,
     UserIpCommand,
     RemoteServerUpCommand,
+    ClearUserSyncsCommand,
 )  # type: Tuple[Type[Command], ...]
 
 # Map of command name to command type.
@@ -420,6 +447,7 @@ VALID_CLIENT_COMMANDS = (
     ReplicateCommand.NAME,
     PingCommand.NAME,
     UserSyncCommand.NAME,
+    ClearUserSyncsCommand.NAME,
     FederationAckCommand.NAME,
     RemovePusherCommand.NAME,
     InvalidateCacheCommand.NAME,
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index f81d2e244..dae246825 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -423,9 +423,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
     async def on_USER_SYNC(self, cmd):
         await self.streamer.on_user_sync(
-            self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+            cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
         )
 
+    async def on_CLEAR_USER_SYNC(self, cmd):
+        await self.streamer.on_clear_user_syncs(cmd.instance_id)
+
     async def on_REPLICATE(self, cmd):
         # Subscribe to all streams we're publishing to.
         for stream_name in self.streamer.streams_by_name:
@@ -551,6 +554,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     ):
         BaseReplicationStreamProtocol.__init__(self, clock)
 
+        self.instance_id = hs.get_instance_id()
+
         self.client_name = client_name
         self.server_name = server_name
         self.handler = handler
@@ -580,7 +585,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         currently_syncing = self.handler.get_currently_syncing_users()
         now = self.clock.time_msec()
         for user_id in currently_syncing:
-            self.send_command(UserSyncCommand(user_id, True, now))
+            self.send_command(UserSyncCommand(self.instance_id, user_id, True, now))
 
         # We've now finished connecting to so inform the client handler
         self.handler.update_connection(self)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 4374e99e3..8b6067e20 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -251,14 +251,19 @@ class ReplicationStreamer(object):
             self.federation_sender.federation_ack(token)
 
     @measure_func("repl.on_user_sync")
-    async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
+    async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
         """A client has started/stopped syncing on a worker.
         """
         user_sync_counter.inc()
         await self.presence_handler.update_external_syncs_row(
-            conn_id, user_id, is_syncing, last_sync_ms
+            instance_id, user_id, is_syncing, last_sync_ms
         )
 
+    async def on_clear_user_syncs(self, instance_id):
+        """A replication client wants us to drop all their UserSync data.
+        """
+        await self.presence_handler.update_external_syncs_clear(instance_id)
+
     @measure_func("repl.on_remove_pusher")
     async def on_remove_pusher(self, app_id, push_key, user_id):
         """A client has asked us to remove a pusher
@@ -321,14 +326,6 @@ class ReplicationStreamer(object):
         except ValueError:
             pass
 
-        # We need to tell the presence handler that the connection has been
-        # lost so that it can handle any ongoing syncs on that connection.
-        run_as_background_process(
-            "update_external_syncs_clear",
-            self.presence_handler.update_external_syncs_clear,
-            connection.conn_id,
-        )
-
 
 def _batch_updates(updates):
     """Takes a list of updates of form [(token, row)] and sets the token to
diff --git a/synapse/server.py b/synapse/server.py
index c7ca2bda0..cd86475d6 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -103,6 +103,7 @@ from synapse.storage import DataStores, Storage
 from synapse.streams.events import EventSources
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
+from synapse.util.stringutils import random_string
 
 logger = logging.getLogger(__name__)
 
@@ -230,6 +231,8 @@ class HomeServer(object):
         self._listening_services = []
         self.start_time = None
 
+        self.instance_id = random_string(5)
+
         self.clock = Clock(reactor)
         self.distributor = Distributor()
         self.ratelimiter = Ratelimiter()
@@ -242,6 +245,14 @@ class HomeServer(object):
         for depname in kwargs:
             setattr(self, depname, kwargs[depname])
 
+    def get_instance_id(self):
+        """A unique ID for this synapse process instance.
+
+        This is used to distinguish running instances in worker-based
+        deployments.
+        """
+        return self.instance_id
+
     def setup(self):
         logger.info("Setting up.")
         self.start_time = int(self.get_clock().time())
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 3844f0e12..9d1dfa71e 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -114,3 +114,5 @@ class HomeServer(object):
         pass
     def is_mine_id(self, domain_id: str) -> bool:
         pass
+    def get_instance_id(self) -> str:
+        pass