Fix caching of remote servers' signature keys

The `@cached` decorator on `KeyStore._get_server_verify_key` was missing
its `num_args` parameter, which meant that it was returning the wrong key for
any server which had more than one recorded key.

By way of a fix, change the default for `num_args` to be *all* arguments. To
implement that, factor out a common base class for `CacheDescriptor` and `CacheListDescriptor`.
This commit is contained in:
Richard van der Hoff 2017-03-22 13:54:20 +00:00
parent 37a187bfab
commit 95f21c7a66
4 changed files with 226 additions and 64 deletions

View file

@ -189,7 +189,55 @@ class Cache(object):
self.cache.clear() self.cache.clear()
class CacheDescriptor(object): class _CacheDescriptorBase(object):
def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
arg_spec = inspect.getargspec(orig)
all_args = arg_spec.args
if "cache_context" in all_args:
if not cache_context:
raise ValueError(
"Cannot have a 'cache_context' arg without setting"
" cache_context=True"
)
elif cache_context:
raise ValueError(
"Cannot have cache_context=True without having an arg"
" named `cache_context`"
)
if num_args is None:
num_args = len(all_args) - 1
if cache_context:
num_args -= 1
if len(all_args) < num_args + 1:
raise Exception(
"Not enough explicit positional arguments to key off for %r: "
"got %i args, but wanted %i. (@cached cannot key off *args or "
"**kwargs)"
% (orig.__name__, len(all_args), num_args)
)
self.num_args = num_args
self.arg_names = all_args[1:num_args + 1]
if "cache_context" in self.arg_names:
raise Exception(
"cache_context arg cannot be included among the cache keys"
)
self.add_cache_context = cache_context
class CacheDescriptor(_CacheDescriptorBase):
""" A method decorator that applies a memoizing cache around the function. """ A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that This caches deferreds, rather than the results themselves. Deferreds that
@ -217,52 +265,24 @@ class CacheDescriptor(object):
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
defer.returnValue(r1 + r2) defer.returnValue(r1 + r2)
Args:
num_args (int): number of positional arguments (excluding ``self`` and
``cache_context``) to use as cache keys. Defaults to all named
args of the function.
""" """
def __init__(self, orig, max_entries=1000, num_args=1, tree=False, def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
inlineCallbacks=False, cache_context=False, iterable=False): inlineCallbacks=False, cache_context=False, iterable=False):
super(CacheDescriptor, self).__init__(
orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
cache_context=cache_context)
max_entries = int(max_entries * CACHE_SIZE_FACTOR) max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.max_entries = max_entries self.max_entries = max_entries
self.num_args = num_args
self.tree = tree self.tree = tree
self.iterable = iterable self.iterable = iterable
all_args = inspect.getargspec(orig)
self.arg_names = all_args.args[1:num_args + 1]
if "cache_context" in all_args.args:
if not cache_context:
raise ValueError(
"Cannot have a 'cache_context' arg without setting"
" cache_context=True"
)
try:
self.arg_names.remove("cache_context")
except ValueError:
pass
elif cache_context:
raise ValueError(
"Cannot have cache_context=True without having an arg"
" named `cache_context`"
)
self.add_cache_context = cache_context
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwargs)"
% (orig.__name__,)
)
def __get__(self, obj, objtype=None): def __get__(self, obj, objtype=None):
cache = Cache( cache = Cache(
name=self.orig.__name__, name=self.orig.__name__,
@ -338,48 +358,36 @@ class CacheDescriptor(object):
return wrapped return wrapped
class CacheListDescriptor(object): class CacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys. """Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped fucntion. the list of missing keys to the wrapped fucntion.
""" """
def __init__(self, orig, cached_method_name, list_name, num_args=1, def __init__(self, orig, cached_method_name, list_name, num_args=None,
inlineCallbacks=False): inlineCallbacks=False):
""" """
Args: Args:
orig (function) orig (function)
method_name (str); The name of the chached method. cached_method_name (str): The name of the chached method.
list_name (str): Name of the argument which is the bulk lookup list list_name (str): Name of the argument which is the bulk lookup list
num_args (int) num_args (int): number of positional arguments (excluding ``self``,
but including list_name) to use as cache keys. Defaults to all
named args of the function.
inlineCallbacks (bool): Whether orig is a generator that should inlineCallbacks (bool): Whether orig is a generator that should
be wrapped by defer.inlineCallbacks be wrapped by defer.inlineCallbacks
""" """
self.orig = orig super(CacheListDescriptor, self).__init__(
orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.num_args = num_args
self.list_name = list_name self.list_name = list_name
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name) self.list_pos = self.arg_names.index(self.list_name)
self.cached_method_name = cached_method_name self.cached_method_name = cached_method_name
self.sentinel = object() self.sentinel = object()
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
if self.list_name not in self.arg_names: if self.list_name not in self.arg_names:
raise Exception( raise Exception(
"Couldn't see arguments %r for %r." "Couldn't see arguments %r for %r."
@ -487,7 +495,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
self.cache.invalidate(self.key) self.cache.invalidate(self.key)
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
iterable=False): iterable=False):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,
@ -499,8 +507,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
) )
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False, def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
iterable=False): cache_context=False, iterable=False):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,
max_entries=max_entries, max_entries=max_entries,
@ -512,7 +520,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
) )
def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`. """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument Used to do batch lookups for an already created cache. A single argument
@ -525,7 +533,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False)
cache (Cache): The underlying cache to use. cache (Cache): The underlying cache to use.
list_name (str): The name of the argument that is the list to use to list_name (str): The name of the argument that is the list to use to
do batch lookups in the cache. do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache. num_args (int): Number of arguments to use as the key in the cache
(including list_name). Defaults to all named parameters.
inlineCallbacks (bool): Should the function be wrapped in an inlineCallbacks (bool): Should the function be wrapped in an
`defer.inlineCallbacks`? `defer.inlineCallbacks`?

View file

@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import signedjson.key
from twisted.internet import defer
import tests.unittest
import tests.utils
class KeyStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(KeyStoreTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.keys.KeyStore
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_get_server_verify_keys(self):
key1 = signedjson.key.decode_verify_key_base64(
"ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
)
key2 = signedjson.key.decode_verify_key_base64(
"ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
)
yield self.store.store_server_verify_key(
"server1", "from_server", 0, key1
)
yield self.store.store_server_verify_key(
"server1", "from_server", 0, key2
)
res = yield self.store.get_server_verify_keys(
"server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"])
self.assertEqual(len(res.keys()), 2)
self.assertEqual(res["ed25519:key1"].version, "key1")
self.assertEqual(res["ed25519:key2"].version, "key2")

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mock
from twisted.internet import defer
from synapse.util.caches import descriptors
from tests import unittest
class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
class Cls(object):
def __init__(self):
self.mock = mock.Mock()
@descriptors.cached()
def fn(self, arg1, arg2):
return self.mock(arg1, arg2)
obj = Cls()
obj.mock.return_value = 'fish'
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = 'chips'
r = yield obj.fn(1, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_called_once_with(1, 3)
obj.mock.reset_mock()
# the two values should now be cached
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
r = yield obj.fn(1, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_not_called()
@defer.inlineCallbacks
def test_cache_num_args(self):
"""Only the first num_args arguments should matter to the cache"""
class Cls(object):
def __init__(self):
self.mock = mock.Mock()
@descriptors.cached(num_args=1)
def fn(self, arg1, arg2):
return self.mock(arg1, arg2)
obj = Cls()
obj.mock.return_value = 'fish'
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = 'chips'
r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_called_once_with(2, 3)
obj.mock.reset_mock()
# the two values should now be cached; we should be able to vary
# the second argument and still get the cached result.
r = yield obj.fn(1, 4)
self.assertEqual(r, 'fish')
r = yield obj.fn(2, 5)
self.assertEqual(r, 'chips')
obj.mock.assert_not_called()