Better return type for get_all_entities_changed (#14604)
Help callers avoid using the return value incorrectly by ensuring that they explicitly check whether there was a cache hit or not.
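For orientation, here is a minimal sketch of the new calling convention introduced by this change (it assumes a tree with this commit applied; the cache name and user IDs are illustrative only). Previously `get_all_entities_changed` returned `Optional[List[str]]`, where `None` meant the cache could not answer; the new result type forces an explicit check:

from synapse.util.caches.stream_change_cache import StreamChangeCache

cache = StreamChangeCache("ExampleCache", 1)       # earliest known stream position: 1
cache.entity_has_changed("@alice:example.org", 2)
cache.entity_has_changed("@bob:example.org", 3)

result = cache.get_all_entities_changed(2)         # changes strictly after position 2
if result.hit:
    changed = result.entities                      # ["@bob:example.org"]
else:
    changed = None  # cache does not go back far enough; fall back to the database

# Asking at or before the earliest known position is a cache miss, not an empty list.
assert not cache.get_all_entities_changed(1).hit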
parent 6a8310f3df
commit cee9445884
8 changed files with 138 additions and 76 deletions
changelog.d/14604.bugfix (new file, +1)
@@ -0,0 +1 @@
+Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances.
@@ -615,8 +615,8 @@ class ApplicationServicesHandler:
         )

         # Fetch the users who have modified their device list since then.
-        users_with_changed_device_lists = (
-            await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
+        users_with_changed_device_lists = await self.store.get_all_devices_changed(
+            from_key, to_key=new_key
         )

         # Filter out any users the application service is not interested in
@@ -1692,10 +1692,12 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):

         if from_key is not None:
             # First get all users that have had a presence update
-            updated_users = stream_change_cache.get_all_entities_changed(from_key)
+            result = stream_change_cache.get_all_entities_changed(from_key)

             # Cross-reference users we're interested in with those that have had updates.
-            if updated_users is not None:
+            if result.hit:
+                updated_users = result.entities
+
                 # If we have the full list of changes for presence we can
                 # simply check which ones share a room with the user.
                 get_updates_counter.labels("stream").inc()
@@ -1767,9 +1769,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
         updated_users = None
         if from_key:
             # Only return updates since the last sync
-            updated_users = self.store.presence_stream_cache.get_all_entities_changed(
-                from_key
-            )
+            result = self.store.presence_stream_cache.get_all_entities_changed(from_key)
+            if result.hit:
+                updated_users = result.entities

         if updated_users is not None:
             # Get the actual presence update for each change
@@ -1528,10 +1528,12 @@ class SyncHandler:
             #
             # If we don't have that info cached then we get all the users that
             # share a room with our user and check if those users have changed.
-            changed_users = self.store.get_cached_device_list_changes(
+            cache_result = self.store.get_cached_device_list_changes(
                 since_token.device_list_key
             )
-            if changed_users is not None:
+            if cache_result.hit:
+                changed_users = cache_result.entities
+
                 result = await self.store.get_rooms_for_users(changed_users)

                 for changed_user_id, entries in result.items():
@@ -420,11 +420,11 @@ class TypingWriterHandler(FollowerTypingHandler):
         if last_id == current_id:
             return [], current_id, False

-        changed_rooms: Optional[
-            Iterable[str]
-        ] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
+        result = self._typing_stream_change_cache.get_all_entities_changed(last_id)

-        if changed_rooms is None:
+        if result.hit:
+            changed_rooms: Iterable[str] = result.entities
+        else:
             changed_rooms = self._room_serials

         rows = []
@@ -58,7 +58,10 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.caches.stream_change_cache import (
+    AllEntitiesChangedResult,
+    StreamChangeCache,
+)
 from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 from synapse.util.stringutils import shortstr
@@ -799,18 +802,66 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     def get_cached_device_list_changes(
         self,
         from_key: int,
-    ) -> Optional[List[str]]:
+    ) -> AllEntitiesChangedResult:
         """Get set of users whose devices have changed since `from_key`, or None
        if that information is not in our cache.
        """

         return self._device_list_stream_cache.get_all_entities_changed(from_key)

+    @cancellable
+    async def get_all_devices_changed(
+        self,
+        from_key: int,
+        to_key: int,
+    ) -> Set[str]:
+        """Get all users whose devices have changed in the given range.
+
+        Args:
+            from_key: The minimum device lists stream token to query device list
+                changes for, exclusive.
+            to_key: The maximum device lists stream token to query device list
+                changes for, inclusive.
+
+        Returns:
+            The set of user_ids whose devices have changed since `from_key`
+            (exclusive) until `to_key` (inclusive).
+        """
+
+        result = self._device_list_stream_cache.get_all_entities_changed(from_key)
+
+        if result.hit:
+            # We know which users might have changed devices.
+            if not result.entities:
+                # If no users then we can return early.
+                return set()
+
+            # Otherwise we need to filter down the list
+            return await self.get_users_whose_devices_changed(
+                from_key, result.entities, to_key
+            )
+
+        # If the cache didn't tell us anything, we just need to query the full
+        # range.
+        sql = """
+            SELECT DISTINCT user_id FROM device_lists_stream
+            WHERE ? < stream_id AND stream_id <= ?
+        """
+
+        rows = await self.db_pool.execute(
+            "get_all_devices_changed",
+            None,
+            sql,
+            from_key,
+            to_key,
+        )
+        return {u for u, in rows}
+
     @cancellable
     async def get_users_whose_devices_changed(
         self,
         from_key: int,
-        user_ids: Optional[Collection[str]] = None,
+        user_ids: Collection[str],
         to_key: Optional[int] = None,
     ) -> Set[str]:
         """Get set of users whose devices have changed since `from_key` that
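The docstring above pins down the range semantics of the new method: `from_key` is exclusive and `to_key` is inclusive, matching the `? < stream_id AND stream_id <= ?` predicate in the fallback query. A standalone sketch of those bounds (plain Python, not Synapse code; the sample stream rows are invented for illustration):

# (stream_id, user_id) rows standing in for device_lists_stream entries
stream_rows = [(5, "@a:x"), (6, "@b:x"), (10, "@c:x")]

def users_changed(from_key, to_key):
    # from_key exclusive, to_key inclusive, as documented for get_all_devices_changed
    return {user for stream_id, user in stream_rows if from_key < stream_id <= to_key}

assert users_changed(5, 10) == {"@b:x", "@c:x"}  # id 5 excluded, id 10 included
assert users_changed(4, 5) == {"@a:x"}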
@@ -830,52 +881,32 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         """
         # Get set of users who *may* have changed. Users not in the returned
         # list have definitely not changed.
-        user_ids_to_check: Optional[Collection[str]]
-        if user_ids is None:
-            # Get set of all users that have had device list changes since 'from_key'
-            user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
-                from_key
-            )
-        else:
-            # The same as above, but filter results to only those users in 'user_ids'
-            user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
-                user_ids, from_key
-            )
+        user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
+            user_ids, from_key
+        )

         # If an empty set was returned, there's nothing to do.
-        if user_ids_to_check is not None and not user_ids_to_check:
+        if not user_ids_to_check:
             return set()

         if to_key is None:
             to_key = self._device_list_id_gen.get_current_token()

         def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
-            stream_id_where_clause = "stream_id > ?"
-            sql_args = [from_key]
-
-            if to_key:
-                stream_id_where_clause += " AND stream_id <= ?"
-                sql_args.append(to_key)
-
-            sql = f"""
+            sql = """
                 SELECT DISTINCT user_id FROM device_lists_stream
-                WHERE {stream_id_where_clause}
+                WHERE ? < stream_id AND stream_id <= ? AND %s
             """

-            # If the stream change cache gave us no information, fetch *all*
-            # users between the stream IDs.
-            if user_ids_to_check is None:
-                txn.execute(sql, sql_args)
-                return {user_id for user_id, in txn}
+            changes: Set[str] = set()

-            # Otherwise, fetch changes for the given users.
-            else:
-                changes: Set[str] = set()
-
-                # Query device changes with a batch of users at a time
-                for chunk in batch_iter(user_ids_to_check, 100):
-                    clause, args = make_in_list_sql_clause(
-                        txn.database_engine, "user_id", chunk
-                    )
-                    txn.execute(sql + " AND " + clause, sql_args + args)
-                    changes.update(user_id for user_id, in txn)
+            # Query device changes with a batch of users at a time
+            for chunk in batch_iter(user_ids_to_check, 100):
+                clause, args = make_in_list_sql_clause(
+                    txn.database_engine, "user_id", chunk
+                )
+                txn.execute(sql % (clause,), [from_key, to_key] + args)
+                changes.update(user_id for user_id, in txn)

             return changes
@@ -16,6 +16,7 @@ import logging
 import math
 from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union

+import attr
 from sortedcontainers import SortedDict

 from synapse.util import caches
@@ -26,6 +27,29 @@ logger = logging.getLogger(__name__)
 EntityType = str


+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class AllEntitiesChangedResult:
+    """Return type of `get_all_entities_changed`.
+
+    Callers must check that there was a cache hit, via `result.hit`, before
+    using the entities in `result.entities`.
+
+    This specifically does *not* implement helpers such as `__bool__` to ensure
+    that callers do the correct checks.
+    """
+
+    _entities: Optional[List[EntityType]]
+
+    @property
+    def hit(self) -> bool:
+        return self._entities is not None
+
+    @property
+    def entities(self) -> List[EntityType]:
+        assert self._entities is not None
+        return self._entities
+
+
 class StreamChangeCache:
     """
     Keeps track of the stream positions of the latest change in a set of entities.
@@ -153,19 +177,19 @@ class StreamChangeCache:
         This will be all entities if the given stream position is at or earlier
         than the earliest known stream position.
         """
-        changed_entities = self.get_all_entities_changed(stream_pos)
-        if changed_entities is not None:
+        cache_result = self.get_all_entities_changed(stream_pos)
+        if cache_result.hit:
             # We now do an intersection, trying to do so in the most efficient
             # way possible (some of these sets are *large*). First check in the
             # given iterable is already a set that we can reuse, otherwise we
             # create a set of the *smallest* of the two iterables and call
             # `intersection(..)` on it (this can be twice as fast as the reverse).
             if isinstance(entities, (set, frozenset)):
-                result = entities.intersection(changed_entities)
-            elif len(changed_entities) < len(entities):
-                result = set(changed_entities).intersection(entities)
+                result = entities.intersection(cache_result.entities)
+            elif len(cache_result.entities) < len(entities):
+                result = set(cache_result.entities).intersection(entities)
             else:
-                result = set(entities).intersection(changed_entities)
+                result = set(entities).intersection(cache_result.entities)
             self.metrics.inc_hits()
         else:
             result = set(entities)
@@ -202,12 +226,12 @@ class StreamChangeCache:
         self.metrics.inc_hits()
         return stream_pos < self._cache.peekitem()[0]

-    def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
+    def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult:
         """
         Returns all entities that have had changes after the given position.

-        If the stream change cache does not go far enough back, i.e. the position
-        is too old, it will return None.
+        If the stream change cache does not go far enough back, i.e. the
+        position is too old, it will return None.

         Returns the entities in the order that they were changed.
@@ -215,23 +239,21 @@ class StreamChangeCache:
             stream_pos: The stream position to check for changes after.

         Return:
-            Entities which have changed after the given stream position.
-
-            None if the given stream position is at or earlier than the earliest
-            known stream position.
+            A class indicating if we have the requested data cached, and if so
+            includes the entities in the order they were changed.
         """
         assert isinstance(stream_pos, int)

         # _cache is not valid at or before the earliest known stream position, so
         # return None to mark that it is unknown if an entity has changed.
         if stream_pos <= self._earliest_known_stream_pos:
-            return None
+            return AllEntitiesChangedResult(None)

         changed_entities: List[EntityType] = []

         for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
             changed_entities.extend(self._cache[k])
-        return changed_entities
+        return AllEntitiesChangedResult(changed_entities)

     def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
         """
@@ -73,8 +73,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         # The oldest item has been popped off
         self.assertTrue("user@foo.com" not in cache._entity_to_key)

-        self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
-        self.assertIsNone(cache.get_all_entities_changed(2))
+        self.assertEqual(
+            cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"]
+        )
+        self.assertFalse(cache.get_all_entities_changed(2).hit)

         # If we update an existing entity, it keeps the two existing entities
         cache.entity_has_changed("bar@baz.net", 5)
@@ -82,10 +84,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
             {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
         )
         self.assertEqual(
-            cache.get_all_entities_changed(3),
+            cache.get_all_entities_changed(3).entities,
             ["user@elsewhere.org", "bar@baz.net"],
         )
-        self.assertIsNone(cache.get_all_entities_changed(2))
+        self.assertFalse(cache.get_all_entities_changed(2).hit)

     def test_get_all_entities_changed(self) -> None:
         """
@@ -105,10 +107,12 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         # Results are ordered so either of these are valid.
         ok1 = ["bar@baz.net", "anotheruser@foo.com", "user@elsewhere.org"]
         ok2 = ["anotheruser@foo.com", "bar@baz.net", "user@elsewhere.org"]
-        self.assertTrue(r == ok1 or r == ok2)
+        self.assertTrue(r.entities == ok1 or r.entities == ok2)

-        self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
-        self.assertEqual(cache.get_all_entities_changed(1), None)
+        self.assertEqual(
+            cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"]
+        )
+        self.assertFalse(cache.get_all_entities_changed(1).hit)

         # ... later, things gest more updates
         cache.entity_has_changed("user@foo.com", 5)
@@ -128,7 +132,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
             "anotheruser@foo.com",
         ]
         r = cache.get_all_entities_changed(3)
-        self.assertTrue(r == ok1 or r == ok2)
+        self.assertTrue(r.entities == ok1 or r.entities == ok2)

     def test_has_any_entity_changed(self) -> None:
         """