From f8712228801d9c4753e65819fa4241818b56953b Mon Sep 17 00:00:00 2001
From: reivilibre <oliverw@matrix.org>
Date: Fri, 1 Apr 2022 13:08:55 +0100
Subject: [PATCH] Move `update_client_ip` background job from the main process
 to the background worker. (#12251)

---
 changelog.d/12251.feature                     |   1 +
 synapse/app/admin_cmd.py                      |   2 -
 synapse/app/generic_worker.py                 |   2 -
 .../replication/slave/storage/client_ips.py   |  59 ----------
 synapse/replication/tcp/commands.py           |   8 +-
 synapse/replication/tcp/handler.py            |  46 +++++---
 synapse/storage/databases/main/__init__.py    |   8 +-
 synapse/storage/databases/main/client_ips.py  | 101 +++++++++++-------
 .../databases/main/monthly_active_users.py    |  60 ++++++-----
 .../storage/databases/main/registration.py    |  24 ++---
 10 files changed, 159 insertions(+), 152 deletions(-)
 create mode 100644 changelog.d/12251.feature
 delete mode 100644 synapse/replication/slave/storage/client_ips.py

diff --git a/changelog.d/12251.feature b/changelog.d/12251.feature
new file mode 100644
index 000000000..ba9ede03c
--- /dev/null
+++ b/changelog.d/12251.feature
@@ -0,0 +1 @@
+Offload the `update_client_ip` background job from the main process to the background worker, when using Redis-based replication.
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 6f8e33a15..2b0d92cba 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -33,7 +33,6 @@ from synapse.handlers.admin import ExfiltrationWriter
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
 from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.slave.storage.events import SlavedEventStore
@@ -61,7 +60,6 @@ class AdminCmdSlavedStore(
     SlavedDeviceStore,
     SlavedPushRuleStore,
     SlavedEventStore,
-    SlavedClientIpStore,
     BaseSlavedStore,
     RoomWorkerStore,
 ):
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index b6f510ed3..1865c671f 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -53,7 +53,6 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
 from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.slave.storage.directory import DirectoryStore
@@ -247,7 +246,6 @@ class GenericWorkerSlavedStore(
     SlavedApplicationServiceStore,
     SlavedRegistrationStore,
     SlavedProfileStore,
-    SlavedClientIpStore,
     SlavedFilteringStore,
     MonthlyActiveUsersWorkerStore,
     MediaRepositoryStore,
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
deleted file mode 100644
index 14706a081..000000000
--- a/synapse/replication/slave/storage/client_ips.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import TYPE_CHECKING
-
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
-from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
-from synapse.util.caches.lrucache import LruCache
-
-from ._base import BaseSlavedStore
-
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
-
-class SlavedClientIpStore(BaseSlavedStore):
-    def __init__(
-        self,
-        database: DatabasePool,
-        db_conn: LoggingDatabaseConnection,
-        hs: "HomeServer",
-    ):
-        super().__init__(database, db_conn, hs)
-
-        self.client_ip_last_seen: LruCache[tuple, int] = LruCache(
-            cache_name="client_ip_last_seen", max_size=50000
-        )
-
-    async def insert_client_ip(
-        self, user_id: str, access_token: str, ip: str, user_agent: str, device_id: str
-    ) -> None:
-        now = int(self._clock.time_msec())
-        key = (user_id, access_token, ip)
-
-        try:
-            last_seen = self.client_ip_last_seen.get(key)
-        except KeyError:
-            last_seen = None
-
-        # Rate-limited inserts
-        if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
-            return
-
-        self.client_ip_last_seen.set(key, now)
-
-        self.hs.get_replication_command_handler().send_user_ip(
-            user_id, access_token, ip, user_agent, device_id, now
-        )
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 3654f6c03..fe3494816 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -356,7 +356,7 @@ class UserIpCommand(Command):
         access_token: str,
         ip: str,
         user_agent: str,
-        device_id: str,
+        device_id: Optional[str],
         last_seen: int,
     ):
         self.user_id = user_id
@@ -389,6 +389,12 @@ class UserIpCommand(Command):
             )
         )
 
+    def __repr__(self) -> str:
+        return (
+            f"UserIpCommand({self.user_id!r}, .., {self.ip!r}, "
+            f"{self.user_agent!r}, {self.device_id!r}, {self.last_seen})"
+        )
+
 
 class RemoteServerUpCommand(_SimpleCommand):
     """Sent when a worker has detected that a remote server is no longer
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b217c35f9..615f1828d 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -235,6 +235,14 @@ class ReplicationCommandHandler:
         if self._is_master:
             self._server_notices_sender = hs.get_server_notices_sender()
 
+        if hs.config.redis.redis_enabled:
+            # If we're using Redis, it's the background worker that should
+            # receive USER_IP commands and store the relevant client IPs.
+            self._should_insert_client_ips = hs.config.worker.run_background_tasks
+        else:
+            # If we're NOT using Redis, this must be handled by the master
+            self._should_insert_client_ips = hs.get_instance_name() == "master"
+
     def _add_command_to_stream_queue(
         self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
     ) -> None:
@@ -401,23 +409,37 @@ class ReplicationCommandHandler:
     ) -> Optional[Awaitable[None]]:
         user_ip_cache_counter.inc()
 
-        if self._is_master:
+        if self._is_master or self._should_insert_client_ips:
+            # We make a point of only returning an awaitable if there's actually
+            # something to do; on_USER_IP is not an async function, but
+            # _handle_user_ip is.
+            # If on_USER_IP returns an awaitable, it gets scheduled as a
+            # background process (see `BaseReplicationStreamProtocol.handle_command`).
             return self._handle_user_ip(cmd)
         else:
+            # Returning None when this process definitely has nothing to do
+            # reduces the overhead of handling the USER_IP command, which is
+            # currently broadcast to all workers regardless of utility.
             return None
 
     async def _handle_user_ip(self, cmd: UserIpCommand) -> None:
-        await self._store.insert_client_ip(
-            cmd.user_id,
-            cmd.access_token,
-            cmd.ip,
-            cmd.user_agent,
-            cmd.device_id,
-            cmd.last_seen,
-        )
+        """
+        Handles a User IP, branching depending on whether we are the main process
+        and/or the background worker.
+        """
+        if self._is_master:
+            assert self._server_notices_sender is not None
+            await self._server_notices_sender.on_user_ip(cmd.user_id)
 
-        assert self._server_notices_sender is not None
-        await self._server_notices_sender.on_user_ip(cmd.user_id)
+        if self._should_insert_client_ips:
+            await self._store.insert_client_ip(
+                cmd.user_id,
+                cmd.access_token,
+                cmd.ip,
+                cmd.user_agent,
+                cmd.device_id,
+                cmd.last_seen,
+            )
 
     def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand) -> None:
         if cmd.instance_name == self._instance_name:
@@ -698,7 +720,7 @@ class ReplicationCommandHandler:
         access_token: str,
         ip: str,
         user_agent: str,
-        device_id: str,
+        device_id: Optional[str],
         last_seen: int,
     ) -> None:
         """Tell the master that the user made a request."""
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index f024761ba..1ea0b2aa6 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -33,7 +33,7 @@ from .account_data import AccountDataStore
 from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
 from .cache import CacheInvalidationWorkerStore
 from .censor_events import CensorEventsStore
-from .client_ips import ClientIpStore
+from .client_ips import ClientIpWorkerStore
 from .deviceinbox import DeviceInboxStore
 from .devices import DeviceStore
 from .directory import DirectoryStore
@@ -49,7 +49,7 @@ from .keys import KeyStore
 from .lock import LockStore
 from .media_repository import MediaRepositoryStore
 from .metrics import ServerMetricsStore
-from .monthly_active_users import MonthlyActiveUsersStore
+from .monthly_active_users import MonthlyActiveUsersWorkerStore
 from .openid import OpenIdStore
 from .presence import PresenceStore
 from .profile import ProfileStore
@@ -112,13 +112,13 @@ class DataStore(
     AccountDataStore,
     EventPushActionsStore,
     OpenIdStore,
-    ClientIpStore,
+    ClientIpWorkerStore,
     DeviceStore,
     DeviceInboxStore,
     UserDirectoryStore,
     GroupServerStore,
     UserErasureStore,
-    MonthlyActiveUsersStore,
+    MonthlyActiveUsersWorkerStore,
     StatsStore,
     RelationsStore,
     CensorEventsStore,
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 8b0c614ec..8480ea4e1 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -25,7 +25,9 @@ from synapse.storage.database import (
     LoggingTransaction,
     make_tuple_comparison_clause,
 )
-from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
+from synapse.storage.databases.main.monthly_active_users import (
+    MonthlyActiveUsersWorkerStore,
+)
 from synapse.types import JsonDict, UserID
 from synapse.util.caches.lrucache import LruCache
 
@@ -397,7 +399,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
         return updated
 
 
-class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
+class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -406,11 +408,40 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
     ):
         super().__init__(database, db_conn, hs)
 
+        if hs.config.redis.redis_enabled:
+            # If we're using Redis, we can shift this update process off to
+            # the background worker
+            self._update_on_this_worker = hs.config.worker.run_background_tasks
+        else:
+            # If we're NOT using Redis, this must be handled by the master
+            self._update_on_this_worker = hs.get_instance_name() == "master"
+
         self.user_ips_max_age = hs.config.server.user_ips_max_age
 
+        # (user_id, access_token, ip,) -> last_seen
+        self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
+            cache_name="client_ip_last_seen", max_size=50000
+        )
+
         if hs.config.worker.run_background_tasks and self.user_ips_max_age:
             self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
 
+        if self._update_on_this_worker:
+            # This is the designated worker that can write to the client IP
+            # tables.
+
+            # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
+            self._batch_row_update: Dict[
+                Tuple[str, str, str], Tuple[str, Optional[str], int]
+            ] = {}
+
+            self._client_ip_looper = self._clock.looping_call(
+                self._update_client_ips_batch, 5 * 1000
+            )
+            self.hs.get_reactor().addSystemEventTrigger(
+                "before", "shutdown", self._update_client_ips_batch
+            )
+
     @wrap_as_background_process("prune_old_user_ips")
     async def _prune_old_user_ips(self) -> None:
         """Removes entries in user IPs older than the configured period."""
@@ -456,7 +487,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             "_prune_old_user_ips", _prune_old_user_ips_txn
         )
 
-    async def get_last_client_ip_by_device(
+    async def _get_last_client_ip_by_device_from_database(
         self, user_id: str, device_id: Optional[str]
     ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
         """For each device_id listed, give the user_ip it was last seen on.
@@ -487,7 +518,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 
         return {(d["user_id"], d["device_id"]): d for d in res}
 
-    async def get_user_ip_and_agents(
+    async def _get_user_ip_and_agents_from_database(
         self, user: UserID, since_ts: int = 0
     ) -> List[LastConnectionInfo]:
         """Fetch the IPs and user agents for a user since the given timestamp.
@@ -539,34 +570,6 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             for access_token, ip, user_agent, last_seen in rows
         ]
 
-
-class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
-    def __init__(
-        self,
-        database: DatabasePool,
-        db_conn: LoggingDatabaseConnection,
-        hs: "HomeServer",
-    ):
-
-        # (user_id, access_token, ip,) -> last_seen
-        self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
-            cache_name="client_ip_last_seen", max_size=50000
-        )
-
-        super().__init__(database, db_conn, hs)
-
-        # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
-        self._batch_row_update: Dict[
-            Tuple[str, str, str], Tuple[str, Optional[str], int]
-        ] = {}
-
-        self._client_ip_looper = self._clock.looping_call(
-            self._update_client_ips_batch, 5 * 1000
-        )
-        self.hs.get_reactor().addSystemEventTrigger(
-            "before", "shutdown", self._update_client_ips_batch
-        )
-
     async def insert_client_ip(
         self,
         user_id: str,
@@ -584,17 +587,27 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
             last_seen = self.client_ip_last_seen.get(key)
         except KeyError:
             last_seen = None
-        await self.populate_monthly_active_users(user_id)
+
         # Rate-limited inserts
         if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
             return
 
         self.client_ip_last_seen.set(key, now)
 
-        self._batch_row_update[key] = (user_agent, device_id, now)
+        if self._update_on_this_worker:
+            await self.populate_monthly_active_users(user_id)
+            self._batch_row_update[key] = (user_agent, device_id, now)
+        else:
+            # We are not the designated writer-worker, so stream over replication
+            self.hs.get_replication_command_handler().send_user_ip(
+                user_id, access_token, ip, user_agent, device_id, now
+            )
 
     @wrap_as_background_process("update_client_ips")
     async def _update_client_ips_batch(self) -> None:
+        assert (
+            self._update_on_this_worker
+        ), "This worker is not designated to update client IPs"
 
         # If the DB pool has already terminated, don't try updating
         if not self.db_pool.is_running():
@@ -612,6 +625,10 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
         txn: LoggingTransaction,
         to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
     ) -> None:
+        assert (
+            self._update_on_this_worker
+        ), "This worker is not designated to update client IPs"
+
         if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
             not self.database_engine.can_native_upsert
         ):
@@ -662,7 +679,12 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
             A dictionary mapping a tuple of (user_id, device_id) to dicts, with
             keys giving the column names from the devices table.
         """
-        ret = await super().get_last_client_ip_by_device(user_id, device_id)
+        ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)
+
+        if not self._update_on_this_worker:
+            # Only the writing-worker has additional in-memory data to enhance
+            # the result
+            return ret
 
         # Update what is retrieved from the database with data which is pending
         # insertion, as if it has already been stored in the database.
@@ -707,9 +729,16 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
             Only the latest user agent for each access token and IP address combination
             is available.
         """
+        rows_from_db = await self._get_user_ip_and_agents_from_database(user, since_ts)
+
+        if not self._update_on_this_worker:
+            # Only the writing-worker has additional in-memory data to enhance
+            # the result
+            return rows_from_db
+
         results: Dict[Tuple[str, str], LastConnectionInfo] = {
             (connection["access_token"], connection["ip"]): connection
-            for connection in await super().get_user_ip_and_agents(user, since_ts)
+            for connection in rows_from_db
         }
 
         # Overlay data that is pending insertion on top of the results from the
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 216622964..4f1c22c71 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -15,7 +15,6 @@ import logging
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
@@ -36,7 +35,7 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
 
-class MonthlyActiveUsersWorkerStore(SQLBaseStore):
+class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -47,9 +46,30 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         self._clock = hs.get_clock()
         self.hs = hs
 
+        if hs.config.redis.redis_enabled:
+            # If we're using Redis, we can shift this update process off to
+            # the background worker
+            self._update_on_this_worker = hs.config.worker.run_background_tasks
+        else:
+            # If we're NOT using Redis, this must be handled by the master
+            self._update_on_this_worker = hs.get_instance_name() == "master"
+
         self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
         self._max_mau_value = hs.config.server.max_mau_value
 
+        self._mau_stats_only = hs.config.server.mau_stats_only
+
+        if self._update_on_this_worker:
+            # Do not add more reserved users than the total allowable number
+            self.db_pool.new_transaction(
+                db_conn,
+                "initialise_mau_threepids",
+                [],
+                [],
+                self._initialise_reserved_users,
+                hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
+            )
+
     @cached(num_args=0)
     async def get_monthly_active_count(self) -> int:
         """Generates current count of monthly active users
@@ -222,28 +242,6 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             "reap_monthly_active_users", _reap_users, reserved_users
         )
 
-
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
-    def __init__(
-        self,
-        database: DatabasePool,
-        db_conn: LoggingDatabaseConnection,
-        hs: "HomeServer",
-    ):
-        super().__init__(database, db_conn, hs)
-
-        self._mau_stats_only = hs.config.server.mau_stats_only
-
-        # Do not add more reserved users than the total allowable number
-        self.db_pool.new_transaction(
-            db_conn,
-            "initialise_mau_threepids",
-            [],
-            [],
-            self._initialise_reserved_users,
-            hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
-        )
-
     def _initialise_reserved_users(
         self, txn: LoggingTransaction, threepids: List[dict]
     ) -> None:
@@ -254,6 +252,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
             txn:
             threepids: List of threepid dicts to reserve
         """
+        assert (
+            self._update_on_this_worker
+        ), "This worker is not designated to update MAUs"
 
         # XXX what is this function trying to achieve?  It upserts into
         # monthly_active_users for each *registered* reserved mau user, but why?
@@ -287,6 +288,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
         Args:
             user_id: user to add/update
         """
+        assert (
+            self._update_on_this_worker
+        ), "This worker is not designated to update MAUs"
+
         # Support user never to be included in MAU stats. Note I can't easily call this
         # from upsert_monthly_active_user_txn because then I need a _txn form of
         # is_support_user which is complicated because I want to cache the result.
@@ -322,6 +327,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
             txn (cursor):
             user_id (str): user to add/update
         """
+        assert (
+            self._update_on_this_worker
+        ), "This worker is not designated to update MAUs"
 
         # Am consciously deciding to lock the table on the basis that is ought
         # never be a big table and alternative approaches (batching multiple
@@ -349,6 +357,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
         Args:
             user_id(str): the user_id to query
         """
+        assert (
+            self._update_on_this_worker
+        ), "This worker is not designated to update MAUs"
+
         if self._limit_usage_by_mau or self._mau_stats_only:
             # Trial users and guests should not be included as part of MAU group
             is_guest = await self.is_guest(user_id)  # type: ignore[attr-defined]
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 7f3d190e9..c7634c92f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1745,6 +1745,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             "replace_refresh_token", _replace_refresh_token_txn
         )
 
+    @cached()
+    async def is_guest(self, user_id: str) -> bool:
+        res = await self.db_pool.simple_select_one_onecol(
+            table="users",
+            keyvalues={"name": user_id},
+            retcol="is_guest",
+            allow_none=True,
+            desc="is_guest",
+        )
+
+        return res if res else False
+
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
     def __init__(
@@ -1887,18 +1899,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
-    @cached()
-    async def is_guest(self, user_id: str) -> bool:
-        res = await self.db_pool.simple_select_one_onecol(
-            table="users",
-            keyvalues={"name": user_id},
-            retcol="is_guest",
-            allow_none=True,
-            desc="is_guest",
-        )
-
-        return res if res else False
-
 
 class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
     def __init__(