Add type hints to DictionaryCache and TTLCache. (#9442)

This commit is contained in:
Patrick Cloke 2021-03-29 12:15:33 -04:00 committed by GitHub
parent 7dcf3fd221
commit 01dd90b0f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 96 additions and 67 deletions

1
changelog.d/9442.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to the caching module.

View file

@ -71,8 +71,10 @@ WELL_KNOWN_RETRY_ATTEMPTS = 3
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_well_known_cache = TTLCache("well-known") _well_known_cache = TTLCache("well-known") # type: TTLCache[bytes, Optional[bytes]]
_had_valid_well_known_cache = TTLCache("had-valid-well-known") _had_valid_well_known_cache = TTLCache(
"had-valid-well-known"
) # type: TTLCache[bytes, bool]
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
@ -88,8 +90,8 @@ class WellKnownResolver:
reactor: IReactorTime, reactor: IReactorTime,
agent: IAgent, agent: IAgent,
user_agent: bytes, user_agent: bytes,
well_known_cache: Optional[TTLCache] = None, well_known_cache: Optional[TTLCache[bytes, Optional[bytes]]] = None,
had_well_known_cache: Optional[TTLCache] = None, had_well_known_cache: Optional[TTLCache[bytes, bool]] = None,
): ):
self._reactor = reactor self._reactor = reactor
self._clock = Clock(reactor) self._clock = Clock(reactor)

View file

@ -183,12 +183,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
requests state from the cache, if False we need to query the DB for the requests state from the cache, if False we need to query the DB for the
missing state. missing state.
""" """
is_all, known_absent, state_dict_ids = cache.get(group) cache_entry = cache.get(group)
state_dict_ids = cache_entry.value
if is_all or state_filter.is_full(): if cache_entry.full or state_filter.is_full():
# Either we have everything or want everything, either way # Either we have everything or want everything, either way
# `is_all` tells us whether we've gotten everything. # `is_all` tells us whether we've gotten everything.
return state_filter.filter_state(state_dict_ids), is_all return state_filter.filter_state(state_dict_ids), cache_entry.full
# tracks whether any of our requested types are missing from the cache # tracks whether any of our requested types are missing from the cache
missing_types = False missing_types = False
@ -202,7 +203,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# There aren't any wild cards, so `concrete_types()` returns the # There aren't any wild cards, so `concrete_types()` returns the
# complete list of event types we're wanting. # complete list of event types we're wanting.
for key in state_filter.concrete_types(): for key in state_filter.concrete_types():
if key not in state_dict_ids and key not in known_absent: if key not in state_dict_ids and key not in cache_entry.known_absent:
missing_types = True missing_types = True
break break

View file

@ -15,26 +15,38 @@
import enum import enum
import logging import logging
import threading import threading
from collections import namedtuple from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar
from typing import Any
import attr
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# The type of the cache keys.
KT = TypeVar("KT")
# The type of the dictionary keys.
DKT = TypeVar("DKT")


@attr.s(slots=True)
class DictionaryEntry:
    """Returned when getting an entry from the cache

    Attributes:
        full: Whether the cache has the full dict or just some keys.
            If not full then not all requested keys will necessarily be present
            in `value`
        known_absent: Keys that were looked up in the dict and were not
            there.
        value: The full or partial dict value
    """

    full = attr.ib(type=bool)
    known_absent = attr.ib()
    value = attr.ib()

    def __len__(self) -> int:
        # Size is the number of dictionary entries held, used by the
        # LruCache size_callback for eviction accounting.
        return len(self.value)
@ -45,21 +57,21 @@ class _Sentinel(enum.Enum):
sentinel = object() sentinel = object()
class DictionaryCache: class DictionaryCache(Generic[KT, DKT]):
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e. """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key. fetching a subset of dictionary keys for a particular key.
""" """
def __init__(self, name: str, max_entries: int = 1000):
    """Create a dictionary cache.

    Args:
        name: the name of the cache, used when registering metrics.
        max_entries: eviction threshold; note that size is measured via
            size_callback=len, i.e. in dictionary entries held (see
            DictionaryEntry.__len__), not in cache keys.
    """
    self.cache = LruCache(
        max_size=max_entries, cache_name=name, size_callback=len
    )  # type: LruCache[KT, DictionaryEntry]

    self.name = name
    # Sequence number, bumped on invalidation so that stale update() calls
    # can be detected and discarded.
    self.sequence = 0
    # The thread the cache was first accessed from (see check_thread).
    self.thread = None  # type: Optional[threading.Thread]
def check_thread(self): def check_thread(self) -> None:
expected_thread = self.thread expected_thread = self.thread
if expected_thread is None: if expected_thread is None:
self.thread = threading.current_thread() self.thread = threading.current_thread()
@ -69,12 +81,14 @@ class DictionaryCache:
"Cache objects can only be accessed from the main thread" "Cache objects can only be accessed from the main thread"
) )
def get(self, key, dict_keys=None): def get(
self, key: KT, dict_keys: Optional[Iterable[DKT]] = None
) -> DictionaryEntry:
"""Fetch an entry out of the cache """Fetch an entry out of the cache
Args: Args:
key key
dict_key(list): If given a set of keys then return only those keys dict_key: If given a set of keys then return only those keys
that exist in the cache. that exist in the cache.
Returns: Returns:
@ -95,7 +109,7 @@ class DictionaryCache:
return DictionaryEntry(False, set(), {}) return DictionaryEntry(False, set(), {})
def invalidate(self, key: KT) -> None:
    """Drop the cached entry for `key`, if any."""
    self.check_thread()

    # Increment the sequence number so that any in-flight update() calls
    # holding the old sequence can be detected as stale — presumably
    # update() compares its `sequence` argument against self.sequence
    # before writing (the comparison is outside this view; confirm).
    self.sequence += 1
    self.cache.pop(key, None)
def invalidate_all(self) -> None:
    """Remove every entry from the cache."""
    self.check_thread()
    # Bump the sequence first (so stale in-flight updates are discarded),
    # then drop everything currently cached.
    self.sequence += 1
    self.cache.clear()
def update(self, sequence, key, value, fetched_keys=None): def update(
self,
sequence: int,
key: KT,
value: Dict[DKT, Any],
fetched_keys: Optional[Set[DKT]] = None,
) -> None:
"""Updates the entry in the cache """Updates the entry in the cache
Args: Args:
sequence sequence
key (K) key
value (dict[X,Y]): The value to update the cache with. value: The value to update the cache with.
fetched_keys (None|set[X]): All of the dictionary keys which were fetched_keys: All of the dictionary keys which were
fetched from the database. fetched from the database.
If None, this is the complete value for key K. Otherwise, it If None, this is the complete value for key K. Otherwise, it
@ -131,7 +151,9 @@ class DictionaryCache:
else: else:
self._update_or_insert(key, value, fetched_keys) self._update_or_insert(key, value, fetched_keys)
def _update_or_insert(self, key, value, known_absent): def _update_or_insert(
self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]
) -> None:
# We pop and reinsert as we need to tell the cache the size may have # We pop and reinsert as we need to tell the cache the size may have
# changed # changed
@ -140,5 +162,5 @@ class DictionaryCache:
entry.known_absent.update(known_absent) entry.known_absent.update(known_absent)
self.cache[key] = entry self.cache[key] = entry
def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None:
    # A fresh insert is marked full=True: `value` is the complete dict for
    # this key (partial values go through _update_or_insert instead).
    self.cache[key] = DictionaryEntry(True, known_absent, value)

View file

@ -15,6 +15,7 @@
import logging import logging
import time import time
from typing import Any, Callable, Dict, Generic, Tuple, TypeVar, Union
import attr import attr
from sortedcontainers import SortedList from sortedcontainers import SortedList
@ -23,15 +24,19 @@ from synapse.util.caches import register_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SENTINEL = object() SENTINEL = object() # type: Any
T = TypeVar("T")
KT = TypeVar("KT")
VT = TypeVar("VT")
class TTLCache: class TTLCache(Generic[KT, VT]):
"""A key/value cache implementation where each entry has its own TTL""" """A key/value cache implementation where each entry has its own TTL"""
def __init__(self, cache_name, timer=time.time): def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
# map from key to _CacheEntry # map from key to _CacheEntry
self._data = {} self._data = {} # type: Dict[KT, _CacheEntry]
# the _CacheEntries, sorted by expiry time # the _CacheEntries, sorted by expiry time
self._expiry_list = SortedList() # type: SortedList[_CacheEntry] self._expiry_list = SortedList() # type: SortedList[_CacheEntry]
@ -40,26 +45,27 @@ class TTLCache:
self._metrics = register_cache("ttl", cache_name, self, resizable=False) self._metrics = register_cache("ttl", cache_name, self, resizable=False)
def set(self, key: KT, value: VT, ttl: float) -> None:
    """Add/update an entry in the cache

    Args:
        key: key for this entry
        value: value for this entry
        ttl: TTL for this entry, in seconds
    """
    expiry = self._timer() + ttl

    self.expire()
    e = self._data.pop(key, SENTINEL)
    # Identity check (`is not`) rather than `!=`: SENTINEL is a marker
    # object, and _CacheEntry is an attrs class so `!=` would invoke
    # value-based equality.
    if e is not SENTINEL:
        assert isinstance(e, _CacheEntry)
        # Remove the stale entry from the expiry index before re-adding.
        self._expiry_list.remove(e)
    entry = _CacheEntry(expiry_time=expiry, ttl=ttl, key=key, value=value)
    self._data[key] = entry
    self._expiry_list.add(entry)
def get(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
    """Get a value from the cache

    Args:
        key: key to look up
        default: value to return if `key` is not found; if not given, a
            missing key raises KeyError instead

    Returns:
        The cached value, or `default` on a miss.

    Raises:
        KeyError: if `key` is not in the cache and no `default` was given
    """
    # Purge expired entries first so a lapsed key counts as a miss.
    self.expire()
    e = self._data.get(key, SENTINEL)
    if e is SENTINEL:
        self._metrics.inc_misses()
        if default is SENTINEL:
            raise KeyError(key)
        return default
    # Narrow the type for mypy: anything stored in _data is a _CacheEntry.
    assert isinstance(e, _CacheEntry)
    self._metrics.inc_hits()
    return e.value
def get_with_expiry(self, key): def get_with_expiry(self, key: KT) -> Tuple[VT, float, float]:
"""Get a value, and its expiry time, from the cache """Get a value, and its expiry time, from the cache
Args: Args:
key: key to look up key: key to look up
Returns: Returns:
Tuple[Any, float, float]: the value from the cache, the expiry time A tuple of the value from the cache, the expiry time and the TTL
and the TTL
Raises: Raises:
KeyError if the entry is not found KeyError if the entry is not found
@ -102,7 +108,7 @@ class TTLCache:
self._metrics.inc_hits() self._metrics.inc_hits()
return e.value, e.expiry_time, e.ttl return e.value, e.expiry_time, e.ttl
def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:  # type: ignore
    """Remove a value from the cache

    If key is in the cache, remove it and return its value, else return default.

    Args:
        key: key to remove
        default: value to return on a miss; if not given, a miss raises
            KeyError instead

    Raises:
        KeyError: if `key` is not in the cache and no `default` was given
    """
    # Purge expired entries first so a lapsed key counts as a miss.
    self.expire()
    e = self._data.pop(key, SENTINEL)
    if e is SENTINEL:
        self._metrics.inc_misses()
        if default is SENTINEL:
            raise KeyError(key)
        return default
    # Narrow the type for mypy: anything stored in _data is a _CacheEntry.
    assert isinstance(e, _CacheEntry)
    # Keep the expiry index in sync with the data dict.
    self._expiry_list.remove(e)
    self._metrics.inc_hits()
    return e.value
def __getitem__(self, key): def __getitem__(self, key: KT) -> VT:
return self.get(key) return self.get(key)
def __delitem__(self, key): def __delitem__(self, key: KT) -> None:
self.pop(key) self.pop(key)
def __contains__(self, key): def __contains__(self, key: KT) -> bool:
return key in self._data return key in self._data
def __len__(self): def __len__(self) -> int:
self.expire() self.expire()
return len(self._data) return len(self._data)
def expire(self): def expire(self) -> None:
"""Run the expiry on the cache. Any entries whose expiry times are due will """Run the expiry on the cache. Any entries whose expiry times are due will
be removed be removed
""" """
@ -158,7 +165,7 @@ class _CacheEntry:
"""TTLCache entry""" """TTLCache entry"""
# expiry_time is the first attribute, so that entries are sorted by expiry. # expiry_time is the first attribute, so that entries are sorted by expiry.
expiry_time = attr.ib() expiry_time = attr.ib(type=float)
ttl = attr.ib() ttl = attr.ib(type=float)
key = attr.ib() key = attr.ib()
value = attr.ib() value = attr.ib()

View file

@ -377,14 +377,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
####################################################### #######################################################
# deliberately remove e2 (room name) from the _state_group_cache # deliberately remove e2 (room name) from the _state_group_cache
( cache_entry = self.state_datastore._state_group_cache.get(group)
is_all, state_dict_ids = cache_entry.value
known_absent,
state_dict_ids,
) = self.state_datastore._state_group_cache.get(group)
self.assertEqual(is_all, True) self.assertEqual(cache_entry.full, True)
self.assertEqual(known_absent, set()) self.assertEqual(cache_entry.known_absent, set())
self.assertDictEqual( self.assertDictEqual(
state_dict_ids, state_dict_ids,
{ {
@ -403,14 +400,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
fetched_keys=((e1.type, e1.state_key),), fetched_keys=((e1.type, e1.state_key),),
) )
( cache_entry = self.state_datastore._state_group_cache.get(group)
is_all, state_dict_ids = cache_entry.value
known_absent,
state_dict_ids,
) = self.state_datastore._state_group_cache.get(group)
self.assertEqual(is_all, False) self.assertEqual(cache_entry.full, False)
self.assertEqual(known_absent, {(e1.type, e1.state_key)}) self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)})
self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id}) self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
############################################ ############################################

View file

@ -27,7 +27,9 @@ class DictCacheTestCase(unittest.TestCase):
key = "test_simple_cache_hit_full" key = "test_simple_cache_hit_full"
v = self.cache.get(key) v = self.cache.get(key)
self.assertEqual((False, set(), {}), v) self.assertIs(v.full, False)
self.assertEqual(v.known_absent, set())
self.assertEqual({}, v.value)
seq = self.cache.sequence seq = self.cache.sequence
test_value = {"test": "test_simple_cache_hit_full"} test_value = {"test": "test_simple_cache_hit_full"}