Don't pull out the full state when calculating push actions (#13078)

This commit is contained in:
Erik Johnston 2022-07-11 21:08:39 +01:00 committed by GitHub
parent bc8eefc1e1
commit e5716b631c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 164 additions and 344 deletions

1
changelog.d/13078.misc Normal file
View file

@ -0,0 +1 @@
Reduce memory consumption when processing incoming events in large rooms.

View file

@ -1 +1 @@
Improve memory usage of calculating push actions for events in large rooms.
Reduce memory consumption when processing incoming events in large rooms.

View file

@ -17,7 +17,6 @@ import itertools
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import attr
from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership, RelationTypes
@ -26,13 +25,11 @@ from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.descriptors import lru_cache
from synapse.util.caches.lrucache import LruCache
from synapse.storage.state import StateFilter
from synapse.util.caches import register_cache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_event_for_clients_with_state
from ..storage.state import StateFilter
from .push_rule_evaluator import PushRuleEvaluatorForEvent
if TYPE_CHECKING:
@ -48,15 +45,6 @@ push_rules_state_size_counter = Counter(
"synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", ""
)
# Measures whether we use the fast path of using state deltas, or if we have to
# recalculate from scratch
push_rules_delta_state_cache_metric = register_cache(
"cache",
"push_rules_delta_state_cache_metric",
cache=[], # Meaningless size, as this isn't a cache that stores values
resizable=False,
)
STATE_EVENT_TYPES_TO_MARK_UNREAD = {
EventTypes.Topic,
@ -111,10 +99,6 @@ class BulkPushRuleEvaluator:
self.clock = hs.get_clock()
self._event_auth_handler = hs.get_event_auth_handler()
# Used by `RulesForRoom` to ensure only one thing mutates the cache at a
# time. Keyed off room_id.
self._rules_linearizer = Linearizer(name="rules_for_room")
self.room_push_rule_cache_metrics = register_cache(
"cache",
"room_push_rule_cache",
@ -126,48 +110,48 @@ class BulkPushRuleEvaluator:
self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled
async def _get_rules_for_event(
self, event: EventBase, context: EventContext
self,
event: EventBase,
) -> Dict[str, List[Dict[str, Any]]]:
"""This gets the rules for all users in the room at the time of the event,
as well as the push rules for the invitee if the event is an invite.
"""Get the push rules for all users who may need to be notified about
the event.
Note: this does not check if the user is allowed to see the event.
Returns:
dict of user_id -> push_rules
Mapping of user ID to their push rules.
"""
room_id = event.room_id
# We get the users who may need to be notified by first fetching the
# local users currently in the room, finding those that have push rules,
# and *then* checking which users are actually allowed to see the event.
#
# The alternative is to first fetch all users that were joined at the
# event, but that requires fetching the full state at the event, which
# may be expensive for large rooms with few local users.
rules_for_room_data = self._get_rules_for_room(room_id)
rules_for_room = RulesForRoom(
hs=self.hs,
room_id=room_id,
rules_for_room_cache=self._get_rules_for_room.cache,
room_push_rule_cache_metrics=self.room_push_rule_cache_metrics,
linearizer=self._rules_linearizer,
cached_data=rules_for_room_data,
)
rules_by_user = await rules_for_room.get_rules(event, context)
local_users = await self.store.get_local_users_in_room(event.room_id)
# if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited
if event.type == "m.room.member" and event.content["membership"] == "invite":
if event.type == EventTypes.Member and event.membership == Membership.INVITE:
invited = event.state_key
if invited and self.hs.is_mine_id(invited):
rules_by_user = dict(rules_by_user)
rules_by_user[invited] = await self.store.get_push_rules_for_user(
invited
)
if invited and self.hs.is_mine_id(invited) and invited not in local_users:
local_users = list(local_users)
local_users.append(invited)
rules_by_user = await self.store.bulk_get_push_rules(local_users)
logger.debug("Users in room: %s", local_users)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Returning push rules for %r %r",
event.room_id,
list(rules_by_user.keys()),
)
return rules_by_user
@lru_cache()
def _get_rules_for_room(self, room_id: str) -> "RulesForRoomData":
"""Get the current RulesForRoomData object for the given room id"""
# It's important that the RulesForRoomData object gets added to self._get_rules_for_room.cache
# before any lookup methods get called on it as otherwise there may be
# a race if invalidate_all gets called (which assumes its in the cache)
return RulesForRoomData()
async def _get_power_levels_and_sender_level(
self, event: EventBase, context: EventContext
) -> Tuple[dict, int]:
@ -262,10 +246,12 @@ class BulkPushRuleEvaluator:
count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event, context)
rules_by_user = await self._get_rules_for_event(event)
actions_by_user: Dict[str, List[Union[dict, str]]] = {}
room_members = await self.store.get_joined_users_from_context(event, context)
room_member_count = await self.store.get_number_joined_users_in_room(
event.room_id
)
(
power_levels,
@ -278,30 +264,36 @@ class BulkPushRuleEvaluator:
evaluator = PushRuleEvaluatorForEvent(
event,
len(room_members),
room_member_count,
sender_power_level,
power_levels,
relations,
self._relations_match_enabled,
)
# If the event is not a state event check if any users ignore the sender.
if not event.is_state():
ignorers = await self.store.ignored_by(event.sender)
else:
ignorers = frozenset()
users = rules_by_user.keys()
profiles = await self.store.get_subset_users_in_room_with_profiles(
event.room_id, users
)
# This is a check for the case where user joins a room without being
# allowed to see history, and then the server receives a delayed event
# from before the user joined, which they should not be pushed for
uids_with_visibility = await filter_event_for_clients_with_state(
self.store, users, event, context
)
for uid, rules in rules_by_user.items():
if event.sender == uid:
continue
if uid in ignorers:
if uid not in uids_with_visibility:
continue
display_name = None
profile_info = room_members.get(uid)
if profile_info:
display_name = profile_info.display_name
profile = profiles.get(uid)
if profile:
display_name = profile.display_name
if not display_name:
# Handle the case where we are pushing a membership event to
@ -346,283 +338,3 @@ MemberMap = Dict[str, Optional[EventIdMembership]]
Rule = Dict[str, dict]
RulesByUser = Dict[str, List[Rule]]
StateGroup = Union[object, int]
@attr.s(slots=True, auto_attribs=True)
class RulesForRoomData:
"""The data stored in the cache by `RulesForRoom`.
We don't store `RulesForRoom` directly in the cache as we want our caches to
*only* include data, and not references to e.g. the data stores.
"""
# event_id -> EventIdMembership
member_map: MemberMap = attr.Factory(dict)
# user_id -> rules
rules_by_user: RulesByUser = attr.Factory(dict)
# The last state group we updated the caches for. If the state_group of
# a new event comes along, we know that we can just return the cached
# result.
# On invalidation of the rules themselves (if the user changes them),
# we invalidate everything and set state_group to `object()`
state_group: StateGroup = attr.Factory(object)
# A sequence number to keep track of when we're allowed to update the
# cache. We bump the sequence number when we invalidate the cache. If
# the sequence number changes while we're calculating stuff we should
# not update the cache with it.
sequence: int = 0
# A cache of user_ids that we *know* aren't interesting, e.g. user_ids
# owned by AS's, or remote users, etc. (I.e. users we will never need to
# calculate push for)
# These never need to be invalidated as we will never set up push for
# them.
uninteresting_user_set: Set[str] = attr.Factory(set)
class RulesForRoom:
"""Caches push rules for users in a room.
This efficiently handles users joining/leaving the room by not invalidating
the entire cache for the room.
A new instance is constructed for each call to
`BulkPushRuleEvaluator._get_rules_for_event`, with the cached data from
previous calls passed in.
"""
def __init__(
self,
hs: "HomeServer",
room_id: str,
rules_for_room_cache: LruCache,
room_push_rule_cache_metrics: CacheMetric,
linearizer: Linearizer,
cached_data: RulesForRoomData,
):
"""
Args:
hs: The HomeServer object.
room_id: The room ID.
rules_for_room_cache: The cache object that caches these
RoomsForUser objects.
room_push_rule_cache_metrics: The metrics object
linearizer: The linearizer used to ensure only one thing mutates
the cache at a time. Keyed off room_id
cached_data: Cached data from previous calls to `self.get_rules`,
can be mutated.
"""
self.room_id = room_id
self.is_mine_id = hs.is_mine_id
self.store = hs.get_datastores().main
self.room_push_rule_cache_metrics = room_push_rule_cache_metrics
# Used to ensure only one thing mutates the cache at a time. Keyed off
# room_id.
self.linearizer = linearizer
self.data = cached_data
# We need to be clever on the invalidating caches callbacks, as
# otherwise the invalidation callback holds a reference to the object,
# potentially causing it to leak.
# To get around this we pass a function that on invalidations looks ups
# the RoomsForUser entry in the cache, rather than keeping a reference
# to self around in the callback.
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
async def get_rules(
self, event: EventBase, context: EventContext
) -> Dict[str, List[Dict[str, dict]]]:
"""Given an event context return the rules for all users who are
currently in the room.
"""
state_group = context.state_group
if state_group and self.data.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits()
return self.data.rules_by_user
async with self.linearizer.queue(self.room_id):
if state_group and self.data.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits()
return self.data.rules_by_user
self.room_push_rule_cache_metrics.inc_misses()
ret_rules_by_user = {}
missing_member_event_ids = {}
if state_group and self.data.state_group == context.prev_group:
# If we have a simple delta then we can reuse most of the previous
# results.
ret_rules_by_user = self.data.rules_by_user
current_state_ids = context.delta_ids
push_rules_delta_state_cache_metric.inc_hits()
else:
current_state_ids = await context.get_current_state_ids()
push_rules_delta_state_cache_metric.inc_misses()
# Ensure the state IDs exist.
assert current_state_ids is not None
push_rules_state_size_counter.inc(len(current_state_ids))
logger.debug(
"Looking for member changes in %r %r", state_group, current_state_ids
)
# Loop through to see which member events we've seen and have rules
# for and which we need to fetch
for key in current_state_ids:
typ, user_id = key
if typ != EventTypes.Member:
continue
if user_id in self.data.uninteresting_user_set:
continue
if not self.is_mine_id(user_id):
self.data.uninteresting_user_set.add(user_id)
continue
if self.store.get_if_app_services_interested_in_user(user_id):
self.data.uninteresting_user_set.add(user_id)
continue
event_id = current_state_ids[key]
res = self.data.member_map.get(event_id, None)
if res:
if res.membership == Membership.JOIN:
rules = self.data.rules_by_user.get(res.user_id, None)
if rules:
ret_rules_by_user[res.user_id] = rules
continue
# If a user has left a room we remove their push rule. If they
# joined then we re-add it later in _update_rules_with_member_event_ids
ret_rules_by_user.pop(user_id, None)
missing_member_event_ids[user_id] = event_id
if missing_member_event_ids:
# If we have some member events we haven't seen, look them up
# and fetch push rules for them if appropriate.
logger.debug("Found new member events %r", missing_member_event_ids)
await self._update_rules_with_member_event_ids(
ret_rules_by_user, missing_member_event_ids, state_group, event
)
else:
# The push rules didn't change but lets update the cache anyway
self.update_cache(
self.data.sequence,
members={}, # There were no membership changes
rules_by_user=ret_rules_by_user,
state_group=state_group,
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Returning push rules for %r %r", self.room_id, ret_rules_by_user.keys()
)
return ret_rules_by_user
async def _update_rules_with_member_event_ids(
self,
ret_rules_by_user: Dict[str, list],
member_event_ids: Dict[str, str],
state_group: Optional[int],
event: EventBase,
) -> None:
"""Update the partially filled rules_by_user dict by fetching rules for
any newly joined users in the `member_event_ids` list.
Args:
ret_rules_by_user: Partially filled dict of push rules. Gets
updated with any new rules.
member_event_ids: Dict of user id to event id for membership events
that have happened since the last time we filled rules_by_user
state_group: The state group we are currently computing push rules
for. Used when updating the cache.
event: The event we are currently computing push rules for.
"""
sequence = self.data.sequence
members = await self.store.get_membership_from_event_ids(
member_event_ids.values()
)
# If the event is a join event then it will be in current state events
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
for event_id in member_event_ids.values():
if event_id == event.event_id:
members[event_id] = EventIdMembership(
user_id=event.state_key, membership=event.membership
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
joined_user_ids = {
entry.user_id
for entry in members.values()
if entry and entry.membership == Membership.JOIN
}
logger.debug("Joined: %r", joined_user_ids)
# Previously we only considered users with pushers or read receipts in that
# room. We can't do this anymore because we use push actions to calculate unread
# counts, which don't rely on the user having pushers or sent a read receipt into
# the room. Therefore we just need to filter for local users here.
user_ids = list(filter(self.is_mine_id, joined_user_ids))
rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb
)
ret_rules_by_user.update(
item for item in rules_by_user.items() if item[0] is not None
)
self.update_cache(sequence, members, ret_rules_by_user, state_group)
def update_cache(
self,
sequence: int,
members: MemberMap,
rules_by_user: RulesByUser,
state_group: StateGroup,
) -> None:
if sequence == self.data.sequence:
self.data.member_map.update(members)
self.data.rules_by_user = rules_by_user
self.data.state_group = state_group
@attr.attrs(slots=True, frozen=True, auto_attribs=True)
class _Invalidation:
# _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
# which means that it it is stored on the bulk_get_push_rules cache entry. In order
# to ensure that we don't accumulate lots of redundant callbacks on the cache entry,
# we need to ensure that two _Invalidation objects are "equal" if they refer to the
# same `cache` and `room_id`.
#
# attrs provides suitable __hash__ and __eq__ methods, provided we remember to
# set `frozen=True`.
cache: LruCache
room_id: str
def __call__(self) -> None:
rules_data = self.cache.get(self.room_id, None, update_metrics=False)
if rules_data:
rules_data.sequence += 1
rules_data.state_group = object()
rules_data.member_map = {}
rules_data.rules_by_user = {}
push_rules_invalidation_counter.inc()

View file

@ -75,6 +75,15 @@ class SQLBaseStore(metaclass=ABCMeta):
self._attempt_to_invalidate_cache(
"get_users_in_room_with_profiles", (room_id,)
)
self._attempt_to_invalidate_cache(
"get_number_joined_users_in_room", (room_id,)
)
self._attempt_to_invalidate_cache("get_local_users_in_room", (room_id,))
for user_id in members_changed:
self._attempt_to_invalidate_cache(
"get_user_in_room_with_profile", (room_id, user_id)
)
# Purge other caches based on room state.
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))

View file

@ -1797,6 +1797,18 @@ class PersistEventsStore:
self.store.get_invited_rooms_for_local_user.invalidate,
(event.state_key,),
)
txn.call_after(
self.store.get_local_users_in_room.invalidate,
(event.room_id,),
)
txn.call_after(
self.store.get_number_joined_users_in_room.invalidate,
(event.room_id,),
)
txn.call_after(
self.store.get_user_in_room_with_profile.invalidate,
(event.room_id, event.state_key),
)
# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.

View file

@ -212,6 +212,60 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (room_id, Membership.JOIN))
return [r[0] for r in txn]
@cached()
def get_user_in_room_with_profile(
self, room_id: str, user_id: str
) -> Dict[str, ProfileInfo]:
raise NotImplementedError()
@cachedList(
cached_method_name="get_user_in_room_with_profile", list_name="user_ids"
)
async def get_subset_users_in_room_with_profiles(
self, room_id: str, user_ids: Collection[str]
) -> Dict[str, ProfileInfo]:
"""Get a mapping from user ID to profile information for a list of users
in a given room.
The profile information comes directly from this room's `m.room.member`
events, and so may be specific to this room rather than part of a user's
global profile. To avoid privacy leaks, the profile data should only be
revealed to users who are already in this room.
Args:
room_id: The ID of the room to retrieve the users of.
user_ids: a list of users in the room to run the query for
Returns:
A mapping from user ID to ProfileInfo.
"""
def _get_subset_users_in_room_with_profiles(
txn: LoggingTransaction,
) -> Dict[str, ProfileInfo]:
clause, ids = make_in_list_sql_clause(
self.database_engine, "m.user_id", user_ids
)
sql = """
SELECT state_key, display_name, avatar_url FROM room_memberships as m
INNER JOIN current_state_events as c
ON m.event_id = c.event_id
AND m.room_id = c.room_id
AND m.user_id = c.state_key
WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? AND %s
""" % (
clause,
)
txn.execute(sql, (room_id, Membership.JOIN, *ids))
return {r[0]: ProfileInfo(display_name=r[1], avatar_url=r[2]) for r in txn}
return await self.db_pool.runInteraction(
"get_subset_users_in_room_with_profiles",
_get_subset_users_in_room_with_profiles,
)
@cached(max_entries=100000, iterable=True)
async def get_users_in_room_with_profiles(
self, room_id: str
@ -337,6 +391,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_room_summary", _get_room_summary_txn
)
@cached()
async def get_number_joined_users_in_room(self, room_id: str) -> int:
return await self.db_pool.simple_select_one_onecol(
table="current_state_events",
keyvalues={"room_id": room_id, "membership": Membership.JOIN},
retcol="COUNT(*)",
desc="get_number_joined_users_in_room",
)
@cached()
async def get_invited_rooms_for_local_user(
self, user_id: str
@ -416,6 +479,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id: str,
membership_list: List[str],
) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
Args:
user_id: The user ID.
membership_list: A list of synapse.api.constants.Membership
values which the user must be in.
Returns:
The RoomsForUser that the user matches the membership types.
"""
# Paranoia check.
if not self.hs.is_mine_id(user_id):
raise Exception(
@ -444,6 +518,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
@cached(iterable=True)
async def get_local_users_in_room(self, room_id: str) -> List[str]:
"""
Retrieves a list of the current roommembers who are local to the server.
"""
return await self.db_pool.simple_select_onecol(
table="local_current_membership",
keyvalues={"room_id": room_id, "membership": Membership.JOIN},
retcol="user_id",
desc="get_local_users_in_room",
)
async def get_local_current_membership_for_user_in_room(
self, user_id: str, room_id: str
) -> Tuple[Optional[str], Optional[str]]:

View file

@ -709,7 +709,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(200, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
self.assertEqual(33, channel.resource_usage.db_txn_count)
self.assertEqual(37, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@ -722,7 +722,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(200, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
self.assertEqual(37, channel.resource_usage.db_txn_count)
self.assertEqual(41, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id