Add some more type annotations to Cache

This commit is contained in:
Richard van der Hoff 2020-10-14 19:40:53 +01:00
parent 629a951b49
commit 7eff59ec91
3 changed files with 62 additions and 24 deletions

View file

@ -26,7 +26,7 @@ class SlavedClientIpStore(BaseSlavedStore):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000
)
) # type: Cache[tuple, int]
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())

View file

@ -13,12 +13,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import functools
import inspect
import logging
import threading
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
from typing import (
Any,
Callable,
Generic,
Iterable,
MutableMapping,
Optional,
Tuple,
TypeVar,
Union,
cast,
)
from weakref import WeakValueDictionary
from prometheus_client import Gauge
@ -38,6 +49,8 @@ logger = logging.getLogger(__name__)
CacheKey = Union[Tuple, Any]
F = TypeVar("F", bound=Callable[..., Any])
KT = TypeVar("KT")
VT = TypeVar("VT")
class _CachedFunction(Generic[F]):
@ -61,13 +74,19 @@ cache_pending_metric = Gauge(
["name"],
)
_CacheSentinel = object()
class _Sentinel(enum.Enum):
# defining a sentinel in this way allows mypy to correctly handle the
# type of a dictionary lookup.
sentinel = object()
class CacheEntry:
__slots__ = ["deferred", "callbacks", "invalidated"]
def __init__(self, deferred, callbacks):
def __init__(
self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
):
self.deferred = deferred
self.callbacks = set(callbacks)
self.invalidated = False
@ -80,7 +99,13 @@ class CacheEntry:
self.callbacks.clear()
class Cache:
class Cache(Generic[KT, VT]):
"""Wraps an LruCache, adding support for Deferred results.
It expects that each entry added with set() will be a Deferred; likewise get()
may return an ObservableDeferred.
"""
__slots__ = (
"cache",
"name",
@ -103,19 +128,23 @@ class Cache:
Args:
name: The name of the cache
max_entries: Maximum amount of entries that the cache will hold
keylen: The length of the tuple used as the cache key
keylen: The length of the tuple used as the cache key. Ignored unless
`tree` is True.
tree: Use a TreeCache instead of a dict as the underlying cache type
iterable: If True, count each item in the cached object as an entry,
rather than each cached object
apply_cache_factor_from_config: Whether cache factors specified in the
config file affect `max_entries`
Returns:
Cache
"""
cache_type = TreeCache if tree else dict
self._pending_deferred_cache = cache_type()
# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self._pending_deferred_cache = (
cache_type()
) # type: MutableMapping[KT, CacheEntry]
# cache is used for completed results and maps to the result itself, rather than
# a Deferred.
self.cache = LruCache(
max_size=max_entries,
keylen=keylen,
@ -155,7 +184,13 @@ class Cache:
"Cache objects can only be accessed from the main thread"
)
def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
def get(
self,
key: KT,
default=_Sentinel.sentinel,
callback: Optional[Callable[[], None]] = None,
update_metrics: bool = True,
):
"""Looks the key up in the caches.
Args:
@ -166,30 +201,32 @@ class Cache:
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
Either an ObservableDeferred or the raw result
Either an ObservableDeferred or the result itself
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel:
val.callbacks.update(callbacks)
if update_metrics:
self.metrics.inc_hits()
return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
if val is not _CacheSentinel:
val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks)
if val is not _Sentinel.sentinel:
self.metrics.inc_hits()
return val
if update_metrics:
self.metrics.inc_misses()
if default is _CacheSentinel:
if default is _Sentinel.sentinel:
raise KeyError()
else:
return default
def set(self, key, value, callback=None):
def set(
self, key: KT, value: defer.Deferred, callback: Optional[Callable[[], None]] = None
) -> ObservableDeferred:
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
@ -248,7 +285,7 @@ class Cache:
observer.addCallbacks(cb, eb)
return observable
def prefill(self, key, value, callback=None):
def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
@ -267,7 +304,7 @@ class Cache:
if entry:
entry.invalidate()
def invalidate_many(self, key):
def invalidate_many(self, key: KT):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
@ -275,7 +312,7 @@ class Cache:
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(key, None)
entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()
@ -396,7 +433,7 @@ class CacheDescriptor(_CacheDescriptorBase):
keylen=self.num_args,
tree=self.tree,
iterable=self.iterable,
)
) # type: Cache[Tuple, Any]
def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into

View file

@ -64,7 +64,8 @@ class LruCache:
Args:
max_size: The maximum amount of entries the cache can hold
keylen: The length of the tuple used as the cache key
keylen: The length of the tuple used as the cache key. Ignored unless
cache_type is `TreeCache`.
cache_type (type):
type of underlying cache to be used. Typically one of dict