Combine LruCache.invalidate and invalidate_many (#9973)

* Make `invalidate` and `invalidate_many` do the same thing

... so that we can do either over the invalidation replication stream, and also
because they always confused me a bit.

* Kill off `invalidate_many`

* changelog
This commit is contained in:
Richard van der Hoff 2021-05-27 10:33:56 +01:00 committed by GitHub
parent f42e4c4eb9
commit 224f2f949b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 52 additions and 52 deletions

1
changelog.d/9973.misc Normal file
View file

@@ -0,0 +1 @@
Make `LruCache.invalidate` support tree invalidation, and remove `invalidate_many`.

View file

@@ -68,7 +68,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
if row.entity.startswith("@"): if row.entity.startswith("@"):
self._device_list_stream_cache.entity_has_changed(row.entity, token) self._device_list_stream_cache.entity_has_changed(row.entity, token)
self.get_cached_devices_for_user.invalidate((row.entity,)) self.get_cached_devices_for_user.invalidate((row.entity,))
self._get_cached_user_device.invalidate_many((row.entity,)) self._get_cached_user_device.invalidate((row.entity,))
self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
else: else:

View file

@@ -171,7 +171,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_latest_event_ids_in_room.invalidate((room_id,)) self.get_latest_event_ids_in_room.invalidate((room_id,))
self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
if not backfilled: if not backfilled:
self._events_stream_cache.entity_has_changed(room_id, stream_ordering) self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
@@ -184,8 +184,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_invited_rooms_for_local_user.invalidate((state_key,)) self.get_invited_rooms_for_local_user.invalidate((state_key,))
if relates_to: if relates_to:
self.get_relations_for_event.invalidate_many((relates_to,)) self.get_relations_for_event.invalidate((relates_to,))
self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) self.get_aggregation_groups_for_event.invalidate((relates_to,))
self.get_applicable_edit.invalidate((relates_to,)) self.get_applicable_edit.invalidate((relates_to,))
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):

View file

@@ -1282,7 +1282,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
) )
txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,)) txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,)) txn.call_after(self._get_cached_user_device.invalidate, (user_id,))
txn.call_after( txn.call_after(
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
) )

View file

@@ -860,7 +860,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
not be deleted. not be deleted.
""" """
txn.call_after( txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many, self.get_unread_event_push_actions_by_room_for_user.invalidate,
(room_id, user_id), (room_id, user_id),
) )

View file

@@ -1748,9 +1748,9 @@ class PersistEventsStore:
}, },
) )
txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,)) txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
txn.call_after( txn.call_after(
self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,) self.store.get_aggregation_groups_for_event.invalidate, (parent_id,)
) )
if rel_type == RelationTypes.REPLACE: if rel_type == RelationTypes.REPLACE:
@@ -1903,7 +1903,7 @@ class PersistEventsStore:
for user_id in user_ids: for user_id in user_ids:
txn.call_after( txn.call_after(
self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
(room_id, user_id), (room_id, user_id),
) )
@@ -1917,7 +1917,7 @@ class PersistEventsStore:
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
# Sad that we have to blow away the cache for the whole room here # Sad that we have to blow away the cache for the whole room here
txn.call_after( txn.call_after(
self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
(room_id,), (room_id,),
) )
txn.execute( txn.execute(

View file

@@ -460,7 +460,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type)) self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate_many((room_id,)) self._get_linearized_receipts_for_room.invalidate((room_id,))
self.get_last_receipt_event_id_for_user.invalidate( self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type) (user_id, room_id, receipt_type)
) )
@@ -659,9 +659,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
) )
txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
# FIXME: This shouldn't invalidate the whole cache # FIXME: This shouldn't invalidate the whole cache
txn.call_after( txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
)
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,

View file

@@ -16,16 +16,7 @@
import enum import enum
import threading import threading
from typing import ( from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, Union
Callable,
Generic,
Iterable,
MutableMapping,
Optional,
TypeVar,
Union,
cast,
)
from prometheus_client import Gauge from prometheus_client import Gauge
@@ -91,7 +82,7 @@ class DeferredCache(Generic[KT, VT]):
# _pending_deferred_cache maps from the key value to a `CacheEntry` object. # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self._pending_deferred_cache = ( self._pending_deferred_cache = (
cache_type() cache_type()
) # type: MutableMapping[KT, CacheEntry] ) # type: Union[TreeCache, MutableMapping[KT, CacheEntry]]
def metrics_cb(): def metrics_cb():
cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
@@ -287,8 +278,17 @@ class DeferredCache(Generic[KT, VT]):
self.cache.set(key, value, callbacks=callbacks) self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key): def invalidate(self, key):
"""Delete a key, or tree of entries
If the cache is backed by a regular dict, then "key" must be of
the right type for this cache
If the cache is backed by a TreeCache, then "key" must be a tuple, but
may be of lower cardinality than the TreeCache - in which case the whole
subtree is deleted.
"""
self.check_thread() self.check_thread()
self.cache.pop(key, None) self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the # if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, which will (a) stop it being returned # _pending_deferred_cache, which will (a) stop it being returned
@@ -299,20 +299,10 @@ class DeferredCache(Generic[KT, VT]):
# run the invalidation callbacks now, rather than waiting for the # run the invalidation callbacks now, rather than waiting for the
# deferred to resolve. # deferred to resolve.
if entry: if entry:
entry.invalidate() # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
# case of a TreeCache, a dict of keys to cache entries. Either way calling
def invalidate_many(self, key: KT): # iterate_tree_cache_entry on it will do the right thing.
self.check_thread() for entry in iterate_tree_cache_entry(entry):
if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
key = cast(KT, key)
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate() entry.invalidate()
def invalidate_all(self): def invalidate_all(self):

View file

@@ -48,7 +48,6 @@ F = TypeVar("F", bound=Callable[..., Any])
class _CachedFunction(Generic[F]): class _CachedFunction(Generic[F]):
invalidate = None # type: Any invalidate = None # type: Any
invalidate_all = None # type: Any invalidate_all = None # type: Any
invalidate_many = None # type: Any
prefill = None # type: Any prefill = None # type: Any
cache = None # type: Any cache = None # type: Any
num_args = None # type: Any num_args = None # type: Any
@@ -262,6 +261,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
): ):
super().__init__(orig, num_args=num_args, cache_context=cache_context) super().__init__(orig, num_args=num_args, cache_context=cache_context)
if tree and self.num_args < 2:
raise RuntimeError(
"tree=True is nonsensical for cached functions with a single parameter"
)
self.max_entries = max_entries self.max_entries = max_entries
self.tree = tree self.tree = tree
self.iterable = iterable self.iterable = iterable
@@ -302,11 +306,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
wrapped = cast(_CachedFunction, _wrapped) wrapped = cast(_CachedFunction, _wrapped)
if self.num_args == 1: if self.num_args == 1:
assert not self.tree
wrapped.invalidate = lambda key: cache.invalidate(key[0]) wrapped.invalidate = lambda key: cache.invalidate(key[0])
wrapped.prefill = lambda key, val: cache.prefill(key[0], val) wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
else: else:
wrapped.invalidate = cache.invalidate wrapped.invalidate = cache.invalidate
wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = cache.prefill wrapped.prefill = cache.prefill
wrapped.invalidate_all = cache.invalidate_all wrapped.invalidate_all = cache.invalidate_all

View file

@@ -152,7 +152,6 @@ class LruCache(Generic[KT, VT]):
""" """
Least-recently-used cache, supporting prometheus metrics and invalidation callbacks. Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples. If cache_type=TreeCache, all keys must be tuples.
""" """
@@ -393,10 +392,16 @@ class LruCache(Generic[KT, VT]):
@synchronized @synchronized
def cache_del_multi(key: KT) -> None: def cache_del_multi(key: KT) -> None:
"""Delete an entry, or tree of entries
If the LruCache is backed by a regular dict, then "key" must be of
the right type for this cache
If the LruCache is backed by a TreeCache, then "key" must be a tuple, but
may be of lower cardinality than the TreeCache - in which case the whole
subtree is deleted.
""" """
This will only work if constructed with cache_type=TreeCache popped = cache.pop(key, None)
"""
popped = cache.pop(key)
if popped is None: if popped is None:
return return
# for each deleted node, we now need to remove it from the linked list # for each deleted node, we now need to remove it from the linked list
@@ -430,11 +435,10 @@ class LruCache(Generic[KT, VT]):
self.set = cache_set self.set = cache_set
self.setdefault = cache_set_default self.setdefault = cache_set_default
self.pop = cache_pop self.pop = cache_pop
self.del_multi = cache_del_multi
# `invalidate` is exposed for consistency with DeferredCache, so that it can be # `invalidate` is exposed for consistency with DeferredCache, so that it can be
# invalidated by the cache invalidation replication stream. # invalidated by the cache invalidation replication stream.
self.invalidate = cache_pop self.invalidate = cache_del_multi
if cache_type is TreeCache:
self.del_multi = cache_del_multi
self.len = synchronized(cache_len) self.len = synchronized(cache_len)
self.contains = cache_contains self.contains = cache_contains
self.clear = cache_clear self.clear = cache_clear

View file

@@ -89,6 +89,9 @@ class TreeCache:
value. If the key is partial, the TreeCacheNode corresponding to the part value. If the key is partial, the TreeCacheNode corresponding to the part
of the tree that was removed. of the tree that was removed.
""" """
if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
# a list of the nodes we have touched on the way down the tree # a list of the nodes we have touched on the way down the tree
nodes = [] nodes = []

View file

@@ -622,17 +622,17 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
self.assertEquals(callcount2[0], 1) self.assertEquals(callcount2[0], 1)
a.func2.invalidate(("foo",)) a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 1) self.assertEquals(a.func2.cache.cache.del_multi.call_count, 1)
yield a.func2("foo") yield a.func2("foo")
a.func2.invalidate(("foo",)) a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 2) self.assertEquals(a.func2.cache.cache.del_multi.call_count, 2)
self.assertEquals(callcount[0], 1) self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 2) self.assertEquals(callcount2[0], 2)
a.func.invalidate(("foo",)) a.func.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 3) self.assertEquals(a.func2.cache.cache.del_multi.call_count, 3)
yield a.func("foo") yield a.func("foo")
self.assertEquals(callcount[0], 2) self.assertEquals(callcount[0], 2)