Add type hints to DictionaryCache and TTLCache. (#9442)

This commit is contained in:
Patrick Cloke 2021-03-29 12:15:33 -04:00 committed by GitHub
parent 7dcf3fd221
commit 01dd90b0f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 96 additions and 67 deletions

1
changelog.d/9442.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to the caching module.

View file

@ -71,8 +71,10 @@ WELL_KNOWN_RETRY_ATTEMPTS = 3
logger = logging.getLogger(__name__)
_well_known_cache = TTLCache("well-known")
_had_valid_well_known_cache = TTLCache("had-valid-well-known")
_well_known_cache = TTLCache("well-known") # type: TTLCache[bytes, Optional[bytes]]
_had_valid_well_known_cache = TTLCache(
"had-valid-well-known"
) # type: TTLCache[bytes, bool]
@attr.s(slots=True, frozen=True)
@ -88,8 +90,8 @@ class WellKnownResolver:
reactor: IReactorTime,
agent: IAgent,
user_agent: bytes,
well_known_cache: Optional[TTLCache] = None,
had_well_known_cache: Optional[TTLCache] = None,
well_known_cache: Optional[TTLCache[bytes, Optional[bytes]]] = None,
had_well_known_cache: Optional[TTLCache[bytes, bool]] = None,
):
self._reactor = reactor
self._clock = Clock(reactor)

View file

@ -183,12 +183,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
requests state from the cache, if False we need to query the DB for the
missing state.
"""
is_all, known_absent, state_dict_ids = cache.get(group)
cache_entry = cache.get(group)
state_dict_ids = cache_entry.value
if is_all or state_filter.is_full():
if cache_entry.full or state_filter.is_full():
# Either we have everything or want everything, either way
# `cache_entry.full` tells us whether we've gotten everything.
return state_filter.filter_state(state_dict_ids), is_all
return state_filter.filter_state(state_dict_ids), cache_entry.full
# tracks whether any of our requested types are missing from the cache
missing_types = False
@ -202,7 +203,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# There aren't any wild cards, so `concrete_types()` returns the
# complete list of event types we're wanting.
for key in state_filter.concrete_types():
if key not in state_dict_ids and key not in known_absent:
if key not in state_dict_ids and key not in cache_entry.known_absent:
missing_types = True
break

View file

@ -15,26 +15,38 @@
import enum
import logging
import threading
from collections import namedtuple
from typing import Any
from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar
import attr
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))):
# The type of the cache keys.
KT = TypeVar("KT")
# The type of the dictionary keys.
DKT = TypeVar("DKT")
@attr.s(slots=True)
class DictionaryEntry:
"""Returned when getting an entry from the cache
Attributes:
full (bool): Whether the cache has the full or dict or just some keys.
full: Whether the cache has the full dict or just some keys.
If not full then not all requested keys will necessarily be present
in `value`
known_absent (set): Keys that were looked up in the dict and were not
known_absent: Keys that were looked up in the dict and were not
there.
value (dict): The full or partial dict value
value: The full or partial dict value
"""
full = attr.ib(type=bool)
known_absent = attr.ib()
value = attr.ib()
def __len__(self):
return len(self.value)
@ -45,21 +57,21 @@ class _Sentinel(enum.Enum):
sentinel = object()
class DictionaryCache:
class DictionaryCache(Generic[KT, DKT]):
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
"""
def __init__(self, name, max_entries=1000):
def __init__(self, name: str, max_entries: int = 1000):
self.cache = LruCache(
max_size=max_entries, cache_name=name, size_callback=len
) # type: LruCache[Any, DictionaryEntry]
) # type: LruCache[KT, DictionaryEntry]
self.name = name
self.sequence = 0
self.thread = None
self.thread = None # type: Optional[threading.Thread]
def check_thread(self):
def check_thread(self) -> None:
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
@ -69,12 +81,14 @@ class DictionaryCache:
"Cache objects can only be accessed from the main thread"
)
def get(self, key, dict_keys=None):
def get(
self, key: KT, dict_keys: Optional[Iterable[DKT]] = None
) -> DictionaryEntry:
"""Fetch an entry out of the cache
Args:
key
dict_key(list): If given a set of keys then return only those keys
dict_keys: If given a set of keys then return only those keys
that exist in the cache.
Returns:
@ -95,7 +109,7 @@ class DictionaryCache:
return DictionaryEntry(False, set(), {})
def invalidate(self, key):
def invalidate(self, key: KT) -> None:
self.check_thread()
# Increment the sequence number so that any SELECT statements that
@ -103,19 +117,25 @@ class DictionaryCache:
self.sequence += 1
self.cache.pop(key, None)
def invalidate_all(self):
def invalidate_all(self) -> None:
self.check_thread()
self.sequence += 1
self.cache.clear()
def update(self, sequence, key, value, fetched_keys=None):
def update(
self,
sequence: int,
key: KT,
value: Dict[DKT, Any],
fetched_keys: Optional[Set[DKT]] = None,
) -> None:
"""Updates the entry in the cache
Args:
sequence
key (K)
value (dict[X,Y]): The value to update the cache with.
fetched_keys (None|set[X]): All of the dictionary keys which were
key
value: The value to update the cache with.
fetched_keys: All of the dictionary keys which were
fetched from the database.
If None, this is the complete value for key K. Otherwise, it
@ -131,7 +151,9 @@ class DictionaryCache:
else:
self._update_or_insert(key, value, fetched_keys)
def _update_or_insert(self, key, value, known_absent):
def _update_or_insert(
self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]
) -> None:
# We pop and reinsert as we need to tell the cache the size may have
# changed
@ -140,5 +162,5 @@ class DictionaryCache:
entry.known_absent.update(known_absent)
self.cache[key] = entry
def _insert(self, key, value, known_absent):
def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None:
self.cache[key] = DictionaryEntry(True, known_absent, value)

View file

@ -15,6 +15,7 @@
import logging
import time
from typing import Any, Callable, Dict, Generic, Tuple, TypeVar, Union
import attr
from sortedcontainers import SortedList
@ -23,15 +24,19 @@ from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
SENTINEL = object()
SENTINEL = object() # type: Any
T = TypeVar("T")
KT = TypeVar("KT")
VT = TypeVar("VT")
class TTLCache:
class TTLCache(Generic[KT, VT]):
"""A key/value cache implementation where each entry has its own TTL"""
def __init__(self, cache_name, timer=time.time):
def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
# map from key to _CacheEntry
self._data = {}
self._data = {} # type: Dict[KT, _CacheEntry]
# the _CacheEntries, sorted by expiry time
self._expiry_list = SortedList() # type: SortedList[_CacheEntry]
@ -40,26 +45,27 @@ class TTLCache:
self._metrics = register_cache("ttl", cache_name, self, resizable=False)
def set(self, key, value, ttl):
def set(self, key: KT, value: VT, ttl: float) -> None:
"""Add/update an entry in the cache
Args:
key: key for this entry
value: value for this entry
ttl (float): TTL for this entry, in seconds
ttl: TTL for this entry, in seconds
"""
expiry = self._timer() + ttl
self.expire()
e = self._data.pop(key, SENTINEL)
if e != SENTINEL:
if e is not SENTINEL:
assert isinstance(e, _CacheEntry)
self._expiry_list.remove(e)
entry = _CacheEntry(expiry_time=expiry, ttl=ttl, key=key, value=value)
self._data[key] = entry
self._expiry_list.add(entry)
def get(self, key, default=SENTINEL):
def get(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
"""Get a value from the cache
Args:
@ -72,23 +78,23 @@ class TTLCache:
"""
self.expire()
e = self._data.get(key, SENTINEL)
if e == SENTINEL:
if e is SENTINEL:
self._metrics.inc_misses()
if default == SENTINEL:
if default is SENTINEL:
raise KeyError(key)
return default
assert isinstance(e, _CacheEntry)
self._metrics.inc_hits()
return e.value
def get_with_expiry(self, key):
def get_with_expiry(self, key: KT) -> Tuple[VT, float, float]:
"""Get a value, and its expiry time, from the cache
Args:
key: key to look up
Returns:
Tuple[Any, float, float]: the value from the cache, the expiry time
and the TTL
A tuple of the value from the cache, the expiry time and the TTL
Raises:
KeyError if the entry is not found
@ -102,7 +108,7 @@ class TTLCache:
self._metrics.inc_hits()
return e.value, e.expiry_time, e.ttl
def pop(self, key, default=SENTINEL):
def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: # type: ignore
"""Remove a value from the cache
If key is in the cache, remove it and return its value, else return default.
@ -118,29 +124,30 @@ class TTLCache:
"""
self.expire()
e = self._data.pop(key, SENTINEL)
if e == SENTINEL:
if e is SENTINEL:
self._metrics.inc_misses()
if default == SENTINEL:
if default is SENTINEL:
raise KeyError(key)
return default
assert isinstance(e, _CacheEntry)
self._expiry_list.remove(e)
self._metrics.inc_hits()
return e.value
def __getitem__(self, key):
def __getitem__(self, key: KT) -> VT:
return self.get(key)
def __delitem__(self, key):
def __delitem__(self, key: KT) -> None:
self.pop(key)
def __contains__(self, key):
def __contains__(self, key: KT) -> bool:
return key in self._data
def __len__(self):
def __len__(self) -> int:
self.expire()
return len(self._data)
def expire(self):
def expire(self) -> None:
"""Run the expiry on the cache. Any entries whose expiry times are due will
be removed
"""
@ -158,7 +165,7 @@ class _CacheEntry:
"""TTLCache entry"""
# expiry_time is the first attribute, so that entries are sorted by expiry.
expiry_time = attr.ib()
ttl = attr.ib()
expiry_time = attr.ib(type=float)
ttl = attr.ib(type=float)
key = attr.ib()
value = attr.ib()

View file

@ -377,14 +377,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
# deliberately remove e2 (room name) from the _state_group_cache
(
is_all,
known_absent,
state_dict_ids,
) = self.state_datastore._state_group_cache.get(group)
cache_entry = self.state_datastore._state_group_cache.get(group)
state_dict_ids = cache_entry.value
self.assertEqual(is_all, True)
self.assertEqual(known_absent, set())
self.assertEqual(cache_entry.full, True)
self.assertEqual(cache_entry.known_absent, set())
self.assertDictEqual(
state_dict_ids,
{
@ -403,14 +400,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
fetched_keys=((e1.type, e1.state_key),),
)
(
is_all,
known_absent,
state_dict_ids,
) = self.state_datastore._state_group_cache.get(group)
cache_entry = self.state_datastore._state_group_cache.get(group)
state_dict_ids = cache_entry.value
self.assertEqual(is_all, False)
self.assertEqual(known_absent, {(e1.type, e1.state_key)})
self.assertEqual(cache_entry.full, False)
self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)})
self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
############################################

View file

@ -27,7 +27,9 @@ class DictCacheTestCase(unittest.TestCase):
key = "test_simple_cache_hit_full"
v = self.cache.get(key)
self.assertEqual((False, set(), {}), v)
self.assertIs(v.full, False)
self.assertEqual(v.known_absent, set())
self.assertEqual({}, v.value)
seq = self.cache.sequence
test_value = {"test": "test_simple_cache_hit_full"}