Better return type for get_all_entities_changed (#14604)

Help prevent callers from using the return value incorrectly by ensuring
that they explicitly check whether there was a cache hit.
This commit is contained in:
Erik Johnston 2022-12-05 20:19:14 +00:00 committed by GitHub
parent 6a8310f3df
commit cee9445884
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 138 additions and 76 deletions

1
changelog.d/14604.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances.

View file

@ -615,8 +615,8 @@ class ApplicationServicesHandler:
)
# Fetch the users who have modified their device list since then.
users_with_changed_device_lists = (
await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
users_with_changed_device_lists = await self.store.get_all_devices_changed(
from_key, to_key=new_key
)
# Filter out any users the application service is not interested in

View file

@ -1692,10 +1692,12 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
if from_key is not None:
# First get all users that have had a presence update
updated_users = stream_change_cache.get_all_entities_changed(from_key)
result = stream_change_cache.get_all_entities_changed(from_key)
# Cross-reference users we're interested in with those that have had updates.
if updated_users is not None:
if result.hit:
updated_users = result.entities
# If we have the full list of changes for presence we can
# simply check which ones share a room with the user.
get_updates_counter.labels("stream").inc()
@ -1767,9 +1769,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
updated_users = None
if from_key:
# Only return updates since the last sync
updated_users = self.store.presence_stream_cache.get_all_entities_changed(
from_key
)
result = self.store.presence_stream_cache.get_all_entities_changed(from_key)
if result.hit:
updated_users = result.entities
if updated_users is not None:
# Get the actual presence update for each change

View file

@ -1528,10 +1528,12 @@ class SyncHandler:
#
# If we don't have that info cached then we get all the users that
# share a room with our user and check if those users have changed.
changed_users = self.store.get_cached_device_list_changes(
cache_result = self.store.get_cached_device_list_changes(
since_token.device_list_key
)
if changed_users is not None:
if cache_result.hit:
changed_users = cache_result.entities
result = await self.store.get_rooms_for_users(changed_users)
for changed_user_id, entries in result.items():

View file

@ -420,11 +420,11 @@ class TypingWriterHandler(FollowerTypingHandler):
if last_id == current_id:
return [], current_id, False
changed_rooms: Optional[
Iterable[str]
] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
result = self._typing_stream_change_cache.get_all_entities_changed(last_id)
if changed_rooms is None:
if result.hit:
changed_rooms: Iterable[str] = result.entities
else:
changed_rooms = self._room_serials
rows = []

View file

@ -58,7 +58,10 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.caches.stream_change_cache import (
AllEntitiesChangedResult,
StreamChangeCache,
)
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@ -799,18 +802,66 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_cached_device_list_changes(
self,
from_key: int,
) -> Optional[List[str]]:
) -> AllEntitiesChangedResult:
"""Get the users whose devices have changed since `from_key`, if that
information is in our cache. Callers must check `hit` on the returned
result before reading its `entities`.
"""
return self._device_list_stream_cache.get_all_entities_changed(from_key)
@cancellable
async def get_all_devices_changed(
    self,
    from_key: int,
    to_key: int,
) -> Set[str]:
    """Find every user whose device list changed within a stream range.

    Args:
        from_key: Lower bound of the device lists stream range, exclusive.
        to_key: Upper bound of the device lists stream range, inclusive.

    Returns:
        The user_ids with device changes after `from_key` (exclusive) and
        up to `to_key` (inclusive).
    """
    cache_result = self._device_list_stream_cache.get_all_entities_changed(from_key)

    if not cache_result.hit:
        # The cache cannot answer for this position, so scan the whole
        # requested range in the database instead.
        sql = """
SELECT DISTINCT user_id FROM device_lists_stream
WHERE ? < stream_id AND stream_id <= ?
"""
        rows = await self.db_pool.execute(
            "get_all_devices_changed",
            None,
            sql,
            from_key,
            to_key,
        )
        return {user_id for (user_id,) in rows}

    # Cache hit: these are the only users that might have changed devices.
    candidates = cache_result.entities
    if not candidates:
        # Nobody changed, so there is nothing to narrow down.
        return set()

    # Narrow the candidates down to those that changed within the range.
    return await self.get_users_whose_devices_changed(from_key, candidates, to_key)
@cancellable
async def get_users_whose_devices_changed(
self,
from_key: int,
user_ids: Optional[Collection[str]] = None,
user_ids: Collection[str],
to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
@ -830,52 +881,32 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"""
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
user_ids_to_check: Optional[Collection[str]]
if user_ids is None:
# Get set of all users that have had device list changes since 'from_key'
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
from_key
)
else:
# The same as above, but filter results to only those users in 'user_ids'
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
user_ids, from_key
)
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
user_ids, from_key
)
# If an empty set was returned, there's nothing to do.
if user_ids_to_check is not None and not user_ids_to_check:
if not user_ids_to_check:
return set()
if to_key is None:
to_key = self._device_list_id_gen.get_current_token()
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
stream_id_where_clause = "stream_id > ?"
sql_args = [from_key]
if to_key:
stream_id_where_clause += " AND stream_id <= ?"
sql_args.append(to_key)
sql = f"""
sql = """
SELECT DISTINCT user_id FROM device_lists_stream
WHERE {stream_id_where_clause}
WHERE ? < stream_id AND stream_id <= ? AND %s
"""
# If the stream change cache gave us no information, fetch *all*
# users between the stream IDs.
if user_ids_to_check is None:
txn.execute(sql, sql_args)
return {user_id for user_id, in txn}
changes: Set[str] = set()
# Otherwise, fetch changes for the given users.
else:
changes: Set[str] = set()
# Query device changes with a batch of users at a time
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
txn.execute(sql + " AND " + clause, sql_args + args)
changes.update(user_id for user_id, in txn)
# Query device changes with a batch of users at a time
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
txn.execute(sql % (clause,), [from_key, to_key] + args)
changes.update(user_id for user_id, in txn)
return changes

View file

@ -16,6 +16,7 @@ import logging
import math
from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union
import attr
from sortedcontainers import SortedDict
from synapse.util import caches
@ -26,6 +27,29 @@ logger = logging.getLogger(__name__)
EntityType = str
@attr.s(auto_attribs=True, frozen=True, slots=True)
class AllEntitiesChangedResult:
    """Outcome of a `get_all_entities_changed` lookup.

    Always consult `hit` before reading `entities`: on a cache miss there is
    no entity list, and reading `entities` trips an assertion.

    Truthiness helpers such as `__bool__` are deliberately left out so that
    callers have to perform that check explicitly.
    """

    # `None` records a cache miss; a list (possibly empty) records a hit.
    _entities: Optional[List[EntityType]]

    @property
    def hit(self) -> bool:
        """Whether the cache was able to answer the lookup."""
        if self._entities is None:
            return False
        return True

    @property
    def entities(self) -> List[EntityType]:
        """The changed entities; only valid to read when `hit` is true."""
        changed = self._entities
        assert changed is not None
        return changed
class StreamChangeCache:
"""
Keeps track of the stream positions of the latest change in a set of entities.
@ -153,19 +177,19 @@ class StreamChangeCache:
This will be all entities if the given stream position is at or earlier
than the earliest known stream position.
"""
changed_entities = self.get_all_entities_changed(stream_pos)
if changed_entities is not None:
cache_result = self.get_all_entities_changed(stream_pos)
if cache_result.hit:
# We now do an intersection, trying to do so in the most efficient
# way possible (some of these sets are *large*). First check in the
# given iterable is already a set that we can reuse, otherwise we
# create a set of the *smallest* of the two iterables and call
# `intersection(..)` on it (this can be twice as fast as the reverse).
if isinstance(entities, (set, frozenset)):
result = entities.intersection(changed_entities)
elif len(changed_entities) < len(entities):
result = set(changed_entities).intersection(entities)
result = entities.intersection(cache_result.entities)
elif len(cache_result.entities) < len(entities):
result = set(cache_result.entities).intersection(entities)
else:
result = set(entities).intersection(changed_entities)
result = set(entities).intersection(cache_result.entities)
self.metrics.inc_hits()
else:
result = set(entities)
@ -202,12 +226,12 @@ class StreamChangeCache:
self.metrics.inc_hits()
return stream_pos < self._cache.peekitem()[0]
def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult:
"""
Returns all entities that have had changes after the given position.
If the stream change cache does not go far enough back, i.e. the position
is too old, it will return None.
If the stream change cache does not go far enough back, i.e. the
position is too old, the returned result reports a cache miss
(`hit` is False).
Returns the entities in the order that they were changed.
@ -215,23 +239,21 @@ class StreamChangeCache:
stream_pos: The stream position to check for changes after.
Return:
Entities which have changed after the given stream position.
None if the given stream position is at or earlier than the earliest
known stream position.
A class indicating if we have the requested data cached, and if so
includes the entities in the order they were changed.
"""
assert isinstance(stream_pos, int)
# _cache is not valid at or before the earliest known stream position, so
# report a cache miss to mark that it is unknown if an entity has changed.
if stream_pos <= self._earliest_known_stream_pos:
return None
return AllEntitiesChangedResult(None)
changed_entities: List[EntityType] = []
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
changed_entities.extend(self._cache[k])
return changed_entities
return AllEntitiesChangedResult(changed_entities)
def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
"""

View file

@ -73,8 +73,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# The oldest item has been popped off
self.assertTrue("user@foo.com" not in cache._entity_to_key)
self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
self.assertIsNone(cache.get_all_entities_changed(2))
self.assertEqual(
cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"]
)
self.assertFalse(cache.get_all_entities_changed(2).hit)
# If we update an existing entity, it keeps the two existing entities
cache.entity_has_changed("bar@baz.net", 5)
@ -82,10 +84,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
{"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
)
self.assertEqual(
cache.get_all_entities_changed(3),
cache.get_all_entities_changed(3).entities,
["user@elsewhere.org", "bar@baz.net"],
)
self.assertIsNone(cache.get_all_entities_changed(2))
self.assertFalse(cache.get_all_entities_changed(2).hit)
def test_get_all_entities_changed(self) -> None:
"""
@ -105,10 +107,12 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# Results are ordered so either of these are valid.
ok1 = ["bar@baz.net", "anotheruser@foo.com", "user@elsewhere.org"]
ok2 = ["anotheruser@foo.com", "bar@baz.net", "user@elsewhere.org"]
self.assertTrue(r == ok1 or r == ok2)
self.assertTrue(r.entities == ok1 or r.entities == ok2)
self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
self.assertEqual(cache.get_all_entities_changed(1), None)
self.assertEqual(
cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"]
)
self.assertFalse(cache.get_all_entities_changed(1).hit)
# ... later, things get more updates
cache.entity_has_changed("user@foo.com", 5)
@ -128,7 +132,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
"anotheruser@foo.com",
]
r = cache.get_all_entities_changed(3)
self.assertTrue(r == ok1 or r == ok2)
self.assertTrue(r.entities == ok1 or r.entities == ok2)
def test_has_any_entity_changed(self) -> None:
"""