Add type hints to expiring cache. (#9730)

Patrick Cloke 2021-04-06 08:58:18 -04:00 committed by GitHub
parent 024f121b74
commit 44bb881096
8 changed files with 65 additions and 54 deletions

changelog.d/9730.misc (new file)

@@ -0,0 +1 @@
+Add type hints to expiring cache.

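The core of the change is making ExpiringCache generic in its key and value types, so each call site can pin both. A minimal sketch of what that buys (the clock variable is an assumed synapse.util.Clock instance, not part of this commit):

    from synapse.util.caches.expiringcache import ExpiringCache

    cache = ExpiringCache(
        cache_name="demo",
        clock=clock,  # assumed Clock instance
        max_len=100,
    )  # type: ExpiringCache[str, int]

    cache["one"] = 1  # ok: str key, int value
    cache[1] = "one"  # mypy error: key and value types are swapped
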
synapse/federation/federation_client.py

@@ -102,7 +102,7 @@ class FederationClient(FederationBase):
             max_len=1000,
             expiry_ms=120 * 1000,
             reset_expiry_on_get=False,
-        )
+        )  # type: ExpiringCache[str, EventBase]

     def _clear_tried_cache(self):
         """Clear pdu_destination_tried cache"""

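The hint is attached as a `# type:` comment rather than a variable annotation, matching the comment style used throughout the codebase at the time. An equivalent annotation-style spelling, as a sketch (constructor arguments other than those shown in the hunk are assumptions):

    self.pdu_destination_tried: ExpiringCache[str, EventBase] = ExpiringCache(
        cache_name="get_pdu_destinations",  # assumed, not shown in the hunk
        clock=self._clock,                  # assumed, not shown in the hunk
        max_len=1000,
        expiry_ms=120 * 1000,
        reset_expiry_on_get=False,
    )
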
synapse/handlers/device.py

@@ -631,7 +631,7 @@ class DeviceListUpdater:
             max_len=10000,
             expiry_ms=30 * 60 * 1000,
             iterable=True,
-        )
+        )  # type: ExpiringCache[str, Set[str]]

         # Attempt to resync out of sync device lists every 30s.
         self._resync_retry_in_progress = False
@@ -760,7 +760,7 @@ class DeviceListUpdater:
         """Given a list of updates for a user figure out if we need to do a full
         resync, or whether we have enough data that we can just apply the delta.
         """
-        seen_updates = self._seen_updates.get(user_id, set())
+        seen_updates = self._seen_updates.get(user_id, set())  # type: Set[str]

         extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)

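With self._seen_updates typed as ExpiringCache[str, Set[str]], the two get() overloads (defined later in this commit) resolve as sketched below; the explicit `# type: Set[str]` comment pins the result where the bare set() literal gives mypy little to infer from:

    hits = self._seen_updates.get(user_id)         # Optional[Set[str]], first overload
    hits = self._seen_updates.get(user_id, set())  # Set[str], second overload
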
synapse/handlers/e2e_keys.py

@@ -38,7 +38,6 @@ from synapse.types import (
 )
 from synapse.util import json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination

 if TYPE_CHECKING:
@@ -1292,17 +1291,6 @@ class SigningKeyEduUpdater:
         # user_id -> list of updates waiting to be handled.
         self._pending_updates = {}  # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]

-        # Recently seen stream ids. We don't bother keeping these in the DB,
-        # but they're useful to have them about to reduce the number of spurious
-        # resyncs.
-        self._seen_updates = ExpiringCache(
-            cache_name="signing_key_update_edu",
-            clock=self.clock,
-            max_len=10000,
-            expiry_ms=30 * 60 * 1000,
-            iterable=True,
-        )
-
     async def incoming_signing_key_update(
         self, origin: str, edu_content: JsonDict
     ) -> None:

synapse/handlers/sync.py

@@ -252,13 +252,13 @@ class SyncHandler:
         self.storage = hs.get_storage()
         self.state_store = self.storage.state

-        # ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
+        # ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
         self.lazy_loaded_members_cache = ExpiringCache(
             "lazy_loaded_members_cache",
             self.clock,
             max_len=0,
             expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
-        )
+        )  # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]

     async def wait_for_sync_for_user(
         self,
@@ -733,8 +733,10 @@ class SyncHandler:
     def get_lazy_loaded_members_cache(
         self, cache_key: Tuple[str, Optional[str]]
-    ) -> LruCache:
-        cache = self.lazy_loaded_members_cache.get(cache_key)
+    ) -> LruCache[str, str]:
+        cache = self.lazy_loaded_members_cache.get(
+            cache_key
+        )  # type: Optional[LruCache[str, str]]
         if cache is None:
             logger.debug("creating LruCache for %r", cache_key)
             cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)

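The new key type encodes the comment above it: keys are (user_id, device_id) pairs with an Optional device_id, and each value is a per-key LruCache of membership events. A usage sketch (sync_handler is an assumed SyncHandler instance; the literal values are invented):

    key = ("@alice:example.com", "DEVICEID")  # type: Tuple[str, Optional[str]]
    members = sync_handler.get_lazy_loaded_members_cache(key)  # LruCache[str, str]
    members["@bob:example.com"] = "$membership_event_id"  # user_id -> event id
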
synapse/rest/media/v1/preview_url_resource.py

@@ -175,7 +175,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             clock=self.clock,
             # don't spider URLs more often than once an hour
             expiry_ms=ONE_HOUR,
-        )
+        )  # type: ExpiringCache[str, ObservableDeferred]

         if self._worker_run_media_background_jobs:
             self._cleaner_loop = self.clock.looping_call(

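Here the cached values are ObservableDeferreds, which lets concurrent requests for the same URL share a single in-flight download. A hypothetical sketch of that coalescing pattern (method and attribute names assumed, not taken from this commit):

    observable = self._cache.get(url)  # Optional[ObservableDeferred]
    if observable is None:
        # start one download and let later callers observe it
        download = run_in_background(self._do_preview, url)
        observable = ObservableDeferred(download, consumeErrors=True)
        self._cache[url] = observable
    return make_deferred_yieldable(observable.observe())
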
synapse/state/__init__.py

@@ -22,6 +22,7 @@ from typing import (
     Callable,
     DefaultDict,
     Dict,
+    FrozenSet,
     Iterable,
     List,
     Optional,
@@ -515,7 +516,7 @@ class StateResolutionHandler:
             expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
             iterable=True,
             reset_expiry_on_get=True,
-        )
+        )  # type: ExpiringCache[FrozenSet[int], _StateCacheEntry]

         #
         # stuff for tracking time spent on state-res by room
@@ -536,7 +537,7 @@ class StateResolutionHandler:
         state_groups_ids: Dict[int, StateMap[str]],
         event_map: Optional[Dict[str, EventBase]],
         state_res_store: "StateResolutionStore",
-    ):
+    ) -> _StateCacheEntry:
         """Resolves conflicts between a set of state groups

         Always generates a new state group (unless we hit the cache), so should

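FrozenSet[int] is the natural key type here: the cache is keyed on the set of state group ids being resolved together, and a dict key must be hashable, which frozenset is and set is not. A sketch of the lookup (the attribute name is assumed from context):

    group_names = frozenset(state_groups_ids)  # type: FrozenSet[int]
    entry = self._state_cache.get(group_names)  # Optional[_StateCacheEntry]
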
synapse/util/caches/expiringcache.py

@@ -15,40 +15,50 @@
 import logging
 from collections import OrderedDict
+from typing import Any, Generic, Optional, TypeVar, Union, overload
+
+import attr
+from typing_extensions import Literal

 from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
 from synapse.util.caches import register_cache

 logger = logging.getLogger(__name__)

-SENTINEL = object()
+SENTINEL = object()  # type: Any


-class ExpiringCache:
+T = TypeVar("T")
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+
+class ExpiringCache(Generic[KT, VT]):
     def __init__(
         self,
-        cache_name,
-        clock,
-        max_len=0,
-        expiry_ms=0,
-        reset_expiry_on_get=False,
-        iterable=False,
+        cache_name: str,
+        clock: Clock,
+        max_len: int = 0,
+        expiry_ms: int = 0,
+        reset_expiry_on_get: bool = False,
+        iterable: bool = False,
     ):
         """
         Args:
-            cache_name (str): Name of this cache, used for logging.
-            clock (Clock)
-            max_len (int): Max size of dict. If the dict grows larger than this
+            cache_name: Name of this cache, used for logging.
+            clock
+            max_len: Max size of dict. If the dict grows larger than this
                 then the oldest items get automatically evicted. Default is 0,
                 which indicates there is no max limit.
-            expiry_ms (int): How long before an item is evicted from the cache
+            expiry_ms: How long before an item is evicted from the cache
                 in milliseconds. Default is 0, indicating items never get
                 evicted based on time.
-            reset_expiry_on_get (bool): If true, will reset the expiry time for
+            reset_expiry_on_get: If true, will reset the expiry time for
                 an item on access. Defaults to False.
-            iterable (bool): If true, the size is calculated by summing the
+            iterable: If true, the size is calculated by summing the
                 sizes of all entries, rather than the number of entries.
         """
         self._cache_name = cache_name
@@ -62,7 +72,7 @@ class ExpiringCache:
         self._expiry_ms = expiry_ms
         self._reset_expiry_on_get = reset_expiry_on_get

-        self._cache = OrderedDict()
+        self._cache = OrderedDict()  # type: OrderedDict[KT, _CacheEntry]

         self.iterable = iterable
@@ -79,12 +89,12 @@ class ExpiringCache:
         self._clock.looping_call(f, self._expiry_ms / 2)

-    def __setitem__(self, key, value):
+    def __setitem__(self, key: KT, value: VT) -> None:
         now = self._clock.time_msec()
         self._cache[key] = _CacheEntry(now, value)
         self.evict()

-    def evict(self):
+    def evict(self) -> None:
         # Evict if there are now too many items
         while self._max_size and len(self) > self._max_size:
             _key, value = self._cache.popitem(last=False)
@@ -93,7 +103,7 @@ class ExpiringCache:
             else:
                 self.metrics.inc_evictions()

-    def __getitem__(self, key):
+    def __getitem__(self, key: KT) -> VT:
         try:
             entry = self._cache[key]
             self.metrics.inc_hits()
@@ -106,7 +116,7 @@ class ExpiringCache:
         return entry.value

-    def pop(self, key, default=SENTINEL):
+    def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
         """Removes and returns the value with the given key from the cache.

         If the key isn't in the cache then `default` will be returned if
@@ -115,29 +125,40 @@ class ExpiringCache:
         Identical functionality to `dict.pop(..)`.
         """

-        value = self._cache.pop(key, default)
+        value = self._cache.pop(key, SENTINEL)
+        # The key was not found.
         if value is SENTINEL:
-            raise KeyError(key)
+            if default is SENTINEL:
+                raise KeyError(key)
+            return default

-        return value
+        return value.value

-    def __contains__(self, key):
+    def __contains__(self, key: KT) -> bool:
         return key in self._cache

-    def get(self, key, default=None):
+    @overload
+    def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]:
+        ...
+
+    @overload
+    def get(self, key: KT, default: T) -> Union[VT, T]:
+        ...
+
+    def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]:
         try:
             return self[key]
         except KeyError:
             return default

-    def setdefault(self, key, value):
+    def setdefault(self, key: KT, value: VT) -> VT:
         try:
             return self[key]
         except KeyError:
             self[key] = value
             return value

-    def _prune_cache(self):
+    def _prune_cache(self) -> None:
         if not self._expiry_ms:
             # zero expiry time means don't expire. This should never get called
             # since we have this check in start too.
@@ -166,7 +187,7 @@ class ExpiringCache:
             len(self),
         )

-    def __len__(self):
+    def __len__(self) -> int:
         if self.iterable:
             return sum(len(entry.value) for entry in self._cache.values())
         else:
@@ -190,9 +211,7 @@ class ExpiringCache:
         return False


-class _CacheEntry:
-    __slots__ = ["time", "value"]
-
-    def __init__(self, time, value):
-        self.time = time
-        self.value = value
+@attr.s(slots=True)
+class _CacheEntry:
+    time = attr.ib(type=int)
+    value = attr.ib()
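
Taken together, the typed API behaves as in this end-to-end sketch (clock is an assumed synapse.util.Clock, and the cache needs a running reactor for its pruning loop when expiry_ms is set, so this is illustrative rather than a standalone script):

    cache = ExpiringCache(
        cache_name="demo",
        clock=clock,
        max_len=2,
    )  # type: ExpiringCache[str, int]

    cache["a"] = 1
    cache["b"] = 2
    cache.get("a")            # Optional[int] via the first overload -> 1
    cache.get("missing", 0)   # Union[int, int] via the second overload -> 0
    cache.pop("b")            # -> 2; raises KeyError if absent and no default
    cache.setdefault("c", 3)  # -> 3, inserted since "c" was missing
    len(cache)                # -> 2 ("a" and "c"); with iterable=True, len()
                              # would instead sum len(value) over all entries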