Move update_client_ip background job from the main process to the background worker. ()

This commit is contained in:
reivilibre 2022-04-01 13:08:55 +01:00 committed by GitHub
parent 319a805cd3
commit f871222880
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 159 additions and 152 deletions

View file

@ -0,0 +1 @@
Offload the `update_client_ip` background job from the main process to the background worker, when using Redis-based replication.

View file

@ -33,7 +33,6 @@ from synapse.handlers.admin import ExfiltrationWriter
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.events import SlavedEventStore
@ -61,7 +60,6 @@ class AdminCmdSlavedStore(
SlavedDeviceStore, SlavedDeviceStore,
SlavedPushRuleStore, SlavedPushRuleStore,
SlavedEventStore, SlavedEventStore,
SlavedClientIpStore,
BaseSlavedStore, BaseSlavedStore,
RoomWorkerStore, RoomWorkerStore,
): ):

View file

@ -53,7 +53,6 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.directory import DirectoryStore
@ -247,7 +246,6 @@ class GenericWorkerSlavedStore(
SlavedApplicationServiceStore, SlavedApplicationServiceStore,
SlavedRegistrationStore, SlavedRegistrationStore,
SlavedProfileStore, SlavedProfileStore,
SlavedClientIpStore,
SlavedFilteringStore, SlavedFilteringStore,
MonthlyActiveUsersWorkerStore, MonthlyActiveUsersWorkerStore,
MediaRepositoryStore, MediaRepositoryStore,

View file

@ -1,59 +0,0 @@
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.lrucache import LruCache
from ._base import BaseSlavedStore
if TYPE_CHECKING:
from synapse.server import HomeServer
class SlavedClientIpStore(BaseSlavedStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.client_ip_last_seen: LruCache[tuple, int] = LruCache(
cache_name="client_ip_last_seen", max_size=50000
)
async def insert_client_ip(
self, user_id: str, access_token: str, ip: str, user_agent: str, device_id: str
) -> None:
now = int(self._clock.time_msec())
key = (user_id, access_token, ip)
try:
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
self.client_ip_last_seen.set(key, now)
self.hs.get_replication_command_handler().send_user_ip(
user_id, access_token, ip, user_agent, device_id, now
)

View file

@ -356,7 +356,7 @@ class UserIpCommand(Command):
access_token: str, access_token: str,
ip: str, ip: str,
user_agent: str, user_agent: str,
device_id: str, device_id: Optional[str],
last_seen: int, last_seen: int,
): ):
self.user_id = user_id self.user_id = user_id
@ -389,6 +389,12 @@ class UserIpCommand(Command):
) )
) )
def __repr__(self) -> str:
return (
f"UserIpCommand({self.user_id!r}, .., {self.ip!r}, "
f"{self.user_agent!r}, {self.device_id!r}, {self.last_seen})"
)
class RemoteServerUpCommand(_SimpleCommand): class RemoteServerUpCommand(_SimpleCommand):
"""Sent when a worker has detected that a remote server is no longer """Sent when a worker has detected that a remote server is no longer

View file

@ -235,6 +235,14 @@ class ReplicationCommandHandler:
if self._is_master: if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender() self._server_notices_sender = hs.get_server_notices_sender()
if hs.config.redis.redis_enabled:
# If we're using Redis, it's the background worker that should
# receive USER_IP commands and store the relevant client IPs.
self._should_insert_client_ips = hs.config.worker.run_background_tasks
else:
# If we're NOT using Redis, this must be handled by the master
self._should_insert_client_ips = hs.get_instance_name() == "master"
def _add_command_to_stream_queue( def _add_command_to_stream_queue(
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand] self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None: ) -> None:
@ -401,23 +409,37 @@ class ReplicationCommandHandler:
) -> Optional[Awaitable[None]]: ) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc() user_ip_cache_counter.inc()
if self._is_master: if self._is_master or self._should_insert_client_ips:
# We make a point of only returning an awaitable if there's actually
# something to do; on_USER_IP is not an async function, but
# _handle_user_ip is.
# If on_USER_IP returns an awaitable, it gets scheduled as a
# background process (see `BaseReplicationStreamProtocol.handle_command`).
return self._handle_user_ip(cmd) return self._handle_user_ip(cmd)
else: else:
# Returning None when this process definitely has nothing to do
# reduces the overhead of handling the USER_IP command, which is
# currently broadcast to all workers regardless of utility.
return None return None
async def _handle_user_ip(self, cmd: UserIpCommand) -> None: async def _handle_user_ip(self, cmd: UserIpCommand) -> None:
await self._store.insert_client_ip( """
cmd.user_id, Handles a User IP, branching depending on whether we are the main process
cmd.access_token, and/or the background worker.
cmd.ip, """
cmd.user_agent, if self._is_master:
cmd.device_id, assert self._server_notices_sender is not None
cmd.last_seen, await self._server_notices_sender.on_user_ip(cmd.user_id)
)
assert self._server_notices_sender is not None if self._should_insert_client_ips:
await self._server_notices_sender.on_user_ip(cmd.user_id) await self._store.insert_client_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
cmd.user_agent,
cmd.device_id,
cmd.last_seen,
)
def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand) -> None: def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand) -> None:
if cmd.instance_name == self._instance_name: if cmd.instance_name == self._instance_name:
@ -698,7 +720,7 @@ class ReplicationCommandHandler:
access_token: str, access_token: str,
ip: str, ip: str,
user_agent: str, user_agent: str,
device_id: str, device_id: Optional[str],
last_seen: int, last_seen: int,
) -> None: ) -> None:
"""Tell the master that the user made a request.""" """Tell the master that the user made a request."""

View file

@ -33,7 +33,7 @@ from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
from .cache import CacheInvalidationWorkerStore from .cache import CacheInvalidationWorkerStore
from .censor_events import CensorEventsStore from .censor_events import CensorEventsStore
from .client_ips import ClientIpStore from .client_ips import ClientIpWorkerStore
from .deviceinbox import DeviceInboxStore from .deviceinbox import DeviceInboxStore
from .devices import DeviceStore from .devices import DeviceStore
from .directory import DirectoryStore from .directory import DirectoryStore
@ -49,7 +49,7 @@ from .keys import KeyStore
from .lock import LockStore from .lock import LockStore
from .media_repository import MediaRepositoryStore from .media_repository import MediaRepositoryStore
from .metrics import ServerMetricsStore from .metrics import ServerMetricsStore
from .monthly_active_users import MonthlyActiveUsersStore from .monthly_active_users import MonthlyActiveUsersWorkerStore
from .openid import OpenIdStore from .openid import OpenIdStore
from .presence import PresenceStore from .presence import PresenceStore
from .profile import ProfileStore from .profile import ProfileStore
@ -112,13 +112,13 @@ class DataStore(
AccountDataStore, AccountDataStore,
EventPushActionsStore, EventPushActionsStore,
OpenIdStore, OpenIdStore,
ClientIpStore, ClientIpWorkerStore,
DeviceStore, DeviceStore,
DeviceInboxStore, DeviceInboxStore,
UserDirectoryStore, UserDirectoryStore,
GroupServerStore, GroupServerStore,
UserErasureStore, UserErasureStore,
MonthlyActiveUsersStore, MonthlyActiveUsersWorkerStore,
StatsStore, StatsStore,
RelationsStore, RelationsStore,
CensorEventsStore, CensorEventsStore,

View file

@ -25,7 +25,9 @@ from synapse.storage.database import (
LoggingTransaction, LoggingTransaction,
make_tuple_comparison_clause, make_tuple_comparison_clause,
) )
from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore from synapse.storage.databases.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
@ -397,7 +399,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return updated return updated
class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorkerStore):
def __init__( def __init__(
self, self,
database: DatabasePool, database: DatabasePool,
@ -406,11 +408,40 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
if hs.config.redis.redis_enabled:
# If we're using Redis, we can shift this update process off to
# the background worker
self._update_on_this_worker = hs.config.worker.run_background_tasks
else:
# If we're NOT using Redis, this must be handled by the master
self._update_on_this_worker = hs.get_instance_name() == "master"
self.user_ips_max_age = hs.config.server.user_ips_max_age self.user_ips_max_age = hs.config.server.user_ips_max_age
# (user_id, access_token, ip,) -> last_seen
self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
cache_name="client_ip_last_seen", max_size=50000
)
if hs.config.worker.run_background_tasks and self.user_ips_max_age: if hs.config.worker.run_background_tasks and self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
if self._update_on_this_worker:
# This is the designated worker that can write to the client IP
# tables.
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update: Dict[
Tuple[str, str, str], Tuple[str, Optional[str], int]
] = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_client_ips_batch
)
@wrap_as_background_process("prune_old_user_ips") @wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self) -> None: async def _prune_old_user_ips(self) -> None:
"""Removes entries in user IPs older than the configured period.""" """Removes entries in user IPs older than the configured period."""
@ -456,7 +487,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
"_prune_old_user_ips", _prune_old_user_ips_txn "_prune_old_user_ips", _prune_old_user_ips_txn
) )
async def get_last_client_ip_by_device( async def _get_last_client_ip_by_device_from_database(
self, user_id: str, device_id: Optional[str] self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]: ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
"""For each device_id listed, give the user_ip it was last seen on. """For each device_id listed, give the user_ip it was last seen on.
@ -487,7 +518,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
return {(d["user_id"], d["device_id"]): d for d in res} return {(d["user_id"], d["device_id"]): d for d in res}
async def get_user_ip_and_agents( async def _get_user_ip_and_agents_from_database(
self, user: UserID, since_ts: int = 0 self, user: UserID, since_ts: int = 0
) -> List[LastConnectionInfo]: ) -> List[LastConnectionInfo]:
"""Fetch the IPs and user agents for a user since the given timestamp. """Fetch the IPs and user agents for a user since the given timestamp.
@ -539,34 +570,6 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
for access_token, ip, user_agent, last_seen in rows for access_token, ip, user_agent, last_seen in rows
] ]
class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
# (user_id, access_token, ip,) -> last_seen
self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
cache_name="client_ip_last_seen", max_size=50000
)
super().__init__(database, db_conn, hs)
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update: Dict[
Tuple[str, str, str], Tuple[str, Optional[str], int]
] = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_client_ips_batch
)
async def insert_client_ip( async def insert_client_ip(
self, self,
user_id: str, user_id: str,
@ -584,17 +587,27 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
last_seen = self.client_ip_last_seen.get(key) last_seen = self.client_ip_last_seen.get(key)
except KeyError: except KeyError:
last_seen = None last_seen = None
await self.populate_monthly_active_users(user_id)
# Rate-limited inserts # Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return return
self.client_ip_last_seen.set(key, now) self.client_ip_last_seen.set(key, now)
self._batch_row_update[key] = (user_agent, device_id, now) if self._update_on_this_worker:
await self.populate_monthly_active_users(user_id)
self._batch_row_update[key] = (user_agent, device_id, now)
else:
# We are not the designated writer-worker, so stream over replication
self.hs.get_replication_command_handler().send_user_ip(
user_id, access_token, ip, user_agent, device_id, now
)
@wrap_as_background_process("update_client_ips") @wrap_as_background_process("update_client_ips")
async def _update_client_ips_batch(self) -> None: async def _update_client_ips_batch(self) -> None:
assert (
self._update_on_this_worker
), "This worker is not designated to update client IPs"
# If the DB pool has already terminated, don't try updating # If the DB pool has already terminated, don't try updating
if not self.db_pool.is_running(): if not self.db_pool.is_running():
@ -612,6 +625,10 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
txn: LoggingTransaction, txn: LoggingTransaction,
to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]], to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
) -> None: ) -> None:
assert (
self._update_on_this_worker
), "This worker is not designated to update client IPs"
if "user_ips" in self.db_pool._unsafe_to_upsert_tables or ( if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
not self.database_engine.can_native_upsert not self.database_engine.can_native_upsert
): ):
@ -662,7 +679,12 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
A dictionary mapping a tuple of (user_id, device_id) to dicts, with A dictionary mapping a tuple of (user_id, device_id) to dicts, with
keys giving the column names from the devices table. keys giving the column names from the devices table.
""" """
ret = await super().get_last_client_ip_by_device(user_id, device_id) ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)
if not self._update_on_this_worker:
# Only the writing-worker has additional in-memory data to enhance
# the result
return ret
# Update what is retrieved from the database with data which is pending # Update what is retrieved from the database with data which is pending
# insertion, as if it has already been stored in the database. # insertion, as if it has already been stored in the database.
@ -707,9 +729,16 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
Only the latest user agent for each access token and IP address combination Only the latest user agent for each access token and IP address combination
is available. is available.
""" """
rows_from_db = await self._get_user_ip_and_agents_from_database(user, since_ts)
if not self._update_on_this_worker:
# Only the writing-worker has additional in-memory data to enhance
# the result
return rows_from_db
results: Dict[Tuple[str, str], LastConnectionInfo] = { results: Dict[Tuple[str, str], LastConnectionInfo] = {
(connection["access_token"], connection["ip"]): connection (connection["access_token"], connection["ip"]): connection
for connection in await super().get_user_ip_and_agents(user, since_ts) for connection in rows_from_db
} }
# Overlay data that is pending insertion on top of the results from the # Overlay data that is pending insertion on top of the results from the

View file

@ -15,7 +15,6 @@ import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
LoggingDatabaseConnection, LoggingDatabaseConnection,
@ -36,7 +35,7 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 60 * 60 * 1000 LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersWorkerStore(SQLBaseStore): class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
def __init__( def __init__(
self, self,
database: DatabasePool, database: DatabasePool,
@ -47,9 +46,30 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.hs = hs self.hs = hs
if hs.config.redis.redis_enabled:
# If we're using Redis, we can shift this update process off to
# the background worker
self._update_on_this_worker = hs.config.worker.run_background_tasks
else:
# If we're NOT using Redis, this must be handled by the master
self._update_on_this_worker = hs.get_instance_name() == "master"
self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
self._max_mau_value = hs.config.server.max_mau_value self._max_mau_value = hs.config.server.max_mau_value
self._mau_stats_only = hs.config.server.mau_stats_only
if self._update_on_this_worker:
# Do not add more reserved users than the total allowable number
self.db_pool.new_transaction(
db_conn,
"initialise_mau_threepids",
[],
[],
self._initialise_reserved_users,
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
@cached(num_args=0) @cached(num_args=0)
async def get_monthly_active_count(self) -> int: async def get_monthly_active_count(self) -> int:
"""Generates current count of monthly active users """Generates current count of monthly active users
@ -222,28 +242,6 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"reap_monthly_active_users", _reap_users, reserved_users "reap_monthly_active_users", _reap_users, reserved_users
) )
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._mau_stats_only = hs.config.server.mau_stats_only
# Do not add more reserved users than the total allowable number
self.db_pool.new_transaction(
db_conn,
"initialise_mau_threepids",
[],
[],
self._initialise_reserved_users,
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
def _initialise_reserved_users( def _initialise_reserved_users(
self, txn: LoggingTransaction, threepids: List[dict] self, txn: LoggingTransaction, threepids: List[dict]
) -> None: ) -> None:
@ -254,6 +252,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
txn: txn:
threepids: List of threepid dicts to reserve threepids: List of threepid dicts to reserve
""" """
assert (
self._update_on_this_worker
), "This worker is not designated to update MAUs"
# XXX what is this function trying to achieve? It upserts into # XXX what is this function trying to achieve? It upserts into
# monthly_active_users for each *registered* reserved mau user, but why? # monthly_active_users for each *registered* reserved mau user, but why?
@ -287,6 +288,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
Args: Args:
user_id: user to add/update user_id: user to add/update
""" """
assert (
self._update_on_this_worker
), "This worker is not designated to update MAUs"
# Support user never to be included in MAU stats. Note I can't easily call this # Support user never to be included in MAU stats. Note I can't easily call this
# from upsert_monthly_active_user_txn because then I need a _txn form of # from upsert_monthly_active_user_txn because then I need a _txn form of
# is_support_user which is complicated because I want to cache the result. # is_support_user which is complicated because I want to cache the result.
@ -322,6 +327,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
txn (cursor): txn (cursor):
user_id (str): user to add/update user_id (str): user to add/update
""" """
assert (
self._update_on_this_worker
), "This worker is not designated to update MAUs"
# Am consciously deciding to lock the table on the basis that is ought # Am consciously deciding to lock the table on the basis that is ought
# never be a big table and alternative approaches (batching multiple # never be a big table and alternative approaches (batching multiple
@ -349,6 +357,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
Args: Args:
user_id(str): the user_id to query user_id(str): the user_id to query
""" """
assert (
self._update_on_this_worker
), "This worker is not designated to update MAUs"
if self._limit_usage_by_mau or self._mau_stats_only: if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group # Trial users and guests should not be included as part of MAU group
is_guest = await self.is_guest(user_id) # type: ignore[attr-defined] is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]

View file

@ -1745,6 +1745,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"replace_refresh_token", _replace_refresh_token_txn "replace_refresh_token", _replace_refresh_token_txn
) )
@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
allow_none=True,
desc="is_guest",
)
return res if res else False
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__( def __init__(
@ -1887,18 +1899,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,))
@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
allow_none=True,
desc="is_guest",
)
return res if res else False
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__( def __init__(