forked from MirrorHub/synapse
Fix caching of remote servers' signature keys
The `@cached` decorator on `KeyStore._get_server_verify_key` was missing its `num_args` parameter, which meant that it was returning the wrong key for any server which had more than one recorded key. By way of a fix, change the default for `num_args` to be *all* arguments. To implement that, factor out a common base class for `CacheDescriptor` and `CacheListDescriptor`.
This commit is contained in:
parent
37a187bfab
commit
95f21c7a66
4 changed files with 226 additions and 64 deletions
|
@ -189,7 +189,55 @@ class Cache(object):
|
||||||
self.cache.clear()
|
self.cache.clear()
|
||||||
|
|
||||||
|
|
||||||
class CacheDescriptor(object):
|
class _CacheDescriptorBase(object):
|
||||||
|
def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
|
||||||
|
self.orig = orig
|
||||||
|
|
||||||
|
if inlineCallbacks:
|
||||||
|
self.function_to_call = defer.inlineCallbacks(orig)
|
||||||
|
else:
|
||||||
|
self.function_to_call = orig
|
||||||
|
|
||||||
|
arg_spec = inspect.getargspec(orig)
|
||||||
|
all_args = arg_spec.args
|
||||||
|
|
||||||
|
if "cache_context" in all_args:
|
||||||
|
if not cache_context:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot have a 'cache_context' arg without setting"
|
||||||
|
" cache_context=True"
|
||||||
|
)
|
||||||
|
elif cache_context:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot have cache_context=True without having an arg"
|
||||||
|
" named `cache_context`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_args is None:
|
||||||
|
num_args = len(all_args) - 1
|
||||||
|
if cache_context:
|
||||||
|
num_args -= 1
|
||||||
|
|
||||||
|
if len(all_args) < num_args + 1:
|
||||||
|
raise Exception(
|
||||||
|
"Not enough explicit positional arguments to key off for %r: "
|
||||||
|
"got %i args, but wanted %i. (@cached cannot key off *args or "
|
||||||
|
"**kwargs)"
|
||||||
|
% (orig.__name__, len(all_args), num_args)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_args = num_args
|
||||||
|
self.arg_names = all_args[1:num_args + 1]
|
||||||
|
|
||||||
|
if "cache_context" in self.arg_names:
|
||||||
|
raise Exception(
|
||||||
|
"cache_context arg cannot be included among the cache keys"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.add_cache_context = cache_context
|
||||||
|
|
||||||
|
|
||||||
|
class CacheDescriptor(_CacheDescriptorBase):
|
||||||
""" A method decorator that applies a memoizing cache around the function.
|
""" A method decorator that applies a memoizing cache around the function.
|
||||||
|
|
||||||
This caches deferreds, rather than the results themselves. Deferreds that
|
This caches deferreds, rather than the results themselves. Deferreds that
|
||||||
|
@ -217,52 +265,24 @@ class CacheDescriptor(object):
|
||||||
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
|
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
|
||||||
defer.returnValue(r1 + r2)
|
defer.returnValue(r1 + r2)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_args (int): number of positional arguments (excluding ``self`` and
|
||||||
|
``cache_context``) to use as cache keys. Defaults to all named
|
||||||
|
args of the function.
|
||||||
"""
|
"""
|
||||||
def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
|
def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
|
||||||
inlineCallbacks=False, cache_context=False, iterable=False):
|
inlineCallbacks=False, cache_context=False, iterable=False):
|
||||||
|
|
||||||
|
super(CacheDescriptor, self).__init__(
|
||||||
|
orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
|
||||||
|
cache_context=cache_context)
|
||||||
|
|
||||||
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
|
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
|
||||||
|
|
||||||
self.orig = orig
|
|
||||||
|
|
||||||
if inlineCallbacks:
|
|
||||||
self.function_to_call = defer.inlineCallbacks(orig)
|
|
||||||
else:
|
|
||||||
self.function_to_call = orig
|
|
||||||
|
|
||||||
self.max_entries = max_entries
|
self.max_entries = max_entries
|
||||||
self.num_args = num_args
|
|
||||||
self.tree = tree
|
self.tree = tree
|
||||||
|
|
||||||
self.iterable = iterable
|
self.iterable = iterable
|
||||||
|
|
||||||
all_args = inspect.getargspec(orig)
|
|
||||||
self.arg_names = all_args.args[1:num_args + 1]
|
|
||||||
|
|
||||||
if "cache_context" in all_args.args:
|
|
||||||
if not cache_context:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot have a 'cache_context' arg without setting"
|
|
||||||
" cache_context=True"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
self.arg_names.remove("cache_context")
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
elif cache_context:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot have cache_context=True without having an arg"
|
|
||||||
" named `cache_context`"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.add_cache_context = cache_context
|
|
||||||
|
|
||||||
if len(self.arg_names) < self.num_args:
|
|
||||||
raise Exception(
|
|
||||||
"Not enough explicit positional arguments to key off of for %r."
|
|
||||||
" (@cached cannot key off of *args or **kwargs)"
|
|
||||||
% (orig.__name__,)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __get__(self, obj, objtype=None):
|
def __get__(self, obj, objtype=None):
|
||||||
cache = Cache(
|
cache = Cache(
|
||||||
name=self.orig.__name__,
|
name=self.orig.__name__,
|
||||||
|
@ -338,48 +358,36 @@ class CacheDescriptor(object):
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
class CacheListDescriptor(object):
|
class CacheListDescriptor(_CacheDescriptorBase):
|
||||||
"""Wraps an existing cache to support bulk fetching of keys.
|
"""Wraps an existing cache to support bulk fetching of keys.
|
||||||
|
|
||||||
Given a list of keys it looks in the cache to find any hits, then passes
|
Given a list of keys it looks in the cache to find any hits, then passes
|
||||||
the list of missing keys to the wrapped fucntion.
|
the list of missing keys to the wrapped fucntion.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, orig, cached_method_name, list_name, num_args=1,
|
def __init__(self, orig, cached_method_name, list_name, num_args=None,
|
||||||
inlineCallbacks=False):
|
inlineCallbacks=False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
orig (function)
|
orig (function)
|
||||||
method_name (str); The name of the chached method.
|
cached_method_name (str): The name of the chached method.
|
||||||
list_name (str): Name of the argument which is the bulk lookup list
|
list_name (str): Name of the argument which is the bulk lookup list
|
||||||
num_args (int)
|
num_args (int): number of positional arguments (excluding ``self``,
|
||||||
|
but including list_name) to use as cache keys. Defaults to all
|
||||||
|
named args of the function.
|
||||||
inlineCallbacks (bool): Whether orig is a generator that should
|
inlineCallbacks (bool): Whether orig is a generator that should
|
||||||
be wrapped by defer.inlineCallbacks
|
be wrapped by defer.inlineCallbacks
|
||||||
"""
|
"""
|
||||||
self.orig = orig
|
super(CacheListDescriptor, self).__init__(
|
||||||
|
orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
|
||||||
|
|
||||||
if inlineCallbacks:
|
|
||||||
self.function_to_call = defer.inlineCallbacks(orig)
|
|
||||||
else:
|
|
||||||
self.function_to_call = orig
|
|
||||||
|
|
||||||
self.num_args = num_args
|
|
||||||
self.list_name = list_name
|
self.list_name = list_name
|
||||||
|
|
||||||
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
|
|
||||||
self.list_pos = self.arg_names.index(self.list_name)
|
self.list_pos = self.arg_names.index(self.list_name)
|
||||||
|
|
||||||
self.cached_method_name = cached_method_name
|
self.cached_method_name = cached_method_name
|
||||||
|
|
||||||
self.sentinel = object()
|
self.sentinel = object()
|
||||||
|
|
||||||
if len(self.arg_names) < self.num_args:
|
|
||||||
raise Exception(
|
|
||||||
"Not enough explicit positional arguments to key off of for %r."
|
|
||||||
" (@cached cannot key off of *args or **kwars)"
|
|
||||||
% (orig.__name__,)
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.list_name not in self.arg_names:
|
if self.list_name not in self.arg_names:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Couldn't see arguments %r for %r."
|
"Couldn't see arguments %r for %r."
|
||||||
|
@ -487,7 +495,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
|
||||||
self.cache.invalidate(self.key)
|
self.cache.invalidate(self.key)
|
||||||
|
|
||||||
|
|
||||||
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
|
def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
|
||||||
iterable=False):
|
iterable=False):
|
||||||
return lambda orig: CacheDescriptor(
|
return lambda orig: CacheDescriptor(
|
||||||
orig,
|
orig,
|
||||||
|
@ -499,8 +507,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
|
def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
|
||||||
iterable=False):
|
cache_context=False, iterable=False):
|
||||||
return lambda orig: CacheDescriptor(
|
return lambda orig: CacheDescriptor(
|
||||||
orig,
|
orig,
|
||||||
max_entries=max_entries,
|
max_entries=max_entries,
|
||||||
|
@ -512,7 +520,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
|
def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
|
||||||
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
|
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
|
||||||
|
|
||||||
Used to do batch lookups for an already created cache. A single argument
|
Used to do batch lookups for an already created cache. A single argument
|
||||||
|
@ -525,7 +533,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False)
|
||||||
cache (Cache): The underlying cache to use.
|
cache (Cache): The underlying cache to use.
|
||||||
list_name (str): The name of the argument that is the list to use to
|
list_name (str): The name of the argument that is the list to use to
|
||||||
do batch lookups in the cache.
|
do batch lookups in the cache.
|
||||||
num_args (int): Number of arguments to use as the key in the cache.
|
num_args (int): Number of arguments to use as the key in the cache
|
||||||
|
(including list_name). Defaults to all named parameters.
|
||||||
inlineCallbacks (bool): Should the function be wrapped in an
|
inlineCallbacks (bool): Should the function be wrapped in an
|
||||||
`defer.inlineCallbacks`?
|
`defer.inlineCallbacks`?
|
||||||
|
|
||||||
|
|
53
tests/storage/test_keys.py
Normal file
53
tests/storage/test_keys.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import signedjson.key
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import tests.unittest
|
||||||
|
import tests.utils
|
||||||
|
|
||||||
|
|
||||||
|
class KeyStoreTestCase(tests.unittest.TestCase):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(KeyStoreTestCase, self).__init__(*args, **kwargs)
|
||||||
|
self.store = None # type: synapse.storage.keys.KeyStore
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def setUp(self):
|
||||||
|
hs = yield tests.utils.setup_test_homeserver()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_get_server_verify_keys(self):
|
||||||
|
key1 = signedjson.key.decode_verify_key_base64(
|
||||||
|
"ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
|
||||||
|
)
|
||||||
|
key2 = signedjson.key.decode_verify_key_base64(
|
||||||
|
"ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
|
||||||
|
)
|
||||||
|
yield self.store.store_server_verify_key(
|
||||||
|
"server1", "from_server", 0, key1
|
||||||
|
)
|
||||||
|
yield self.store.store_server_verify_key(
|
||||||
|
"server1", "from_server", 0, key2
|
||||||
|
)
|
||||||
|
|
||||||
|
res = yield self.store.get_server_verify_keys(
|
||||||
|
"server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"])
|
||||||
|
|
||||||
|
self.assertEqual(len(res.keys()), 2)
|
||||||
|
self.assertEqual(res["ed25519:key1"].version, "key1")
|
||||||
|
self.assertEqual(res["ed25519:key2"].version, "key2")
|
14
tests/util/caches/__init__.py
Normal file
14
tests/util/caches/__init__.py
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
86
tests/util/caches/test_descriptors.py
Normal file
86
tests/util/caches/test_descriptors.py
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import mock
|
||||||
|
from twisted.internet import defer
|
||||||
|
from synapse.util.caches import descriptors
|
||||||
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class DescriptorTestCase(unittest.TestCase):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_cache(self):
|
||||||
|
class Cls(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.mock = mock.Mock()
|
||||||
|
|
||||||
|
@descriptors.cached()
|
||||||
|
def fn(self, arg1, arg2):
|
||||||
|
return self.mock(arg1, arg2)
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
|
||||||
|
obj.mock.return_value = 'fish'
|
||||||
|
r = yield obj.fn(1, 2)
|
||||||
|
self.assertEqual(r, 'fish')
|
||||||
|
obj.mock.assert_called_once_with(1, 2)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# a call with different params should call the mock again
|
||||||
|
obj.mock.return_value = 'chips'
|
||||||
|
r = yield obj.fn(1, 3)
|
||||||
|
self.assertEqual(r, 'chips')
|
||||||
|
obj.mock.assert_called_once_with(1, 3)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# the two values should now be cached
|
||||||
|
r = yield obj.fn(1, 2)
|
||||||
|
self.assertEqual(r, 'fish')
|
||||||
|
r = yield obj.fn(1, 3)
|
||||||
|
self.assertEqual(r, 'chips')
|
||||||
|
obj.mock.assert_not_called()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_cache_num_args(self):
|
||||||
|
"""Only the first num_args arguments should matter to the cache"""
|
||||||
|
|
||||||
|
class Cls(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.mock = mock.Mock()
|
||||||
|
|
||||||
|
@descriptors.cached(num_args=1)
|
||||||
|
def fn(self, arg1, arg2):
|
||||||
|
return self.mock(arg1, arg2)
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
obj.mock.return_value = 'fish'
|
||||||
|
r = yield obj.fn(1, 2)
|
||||||
|
self.assertEqual(r, 'fish')
|
||||||
|
obj.mock.assert_called_once_with(1, 2)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# a call with different params should call the mock again
|
||||||
|
obj.mock.return_value = 'chips'
|
||||||
|
r = yield obj.fn(2, 3)
|
||||||
|
self.assertEqual(r, 'chips')
|
||||||
|
obj.mock.assert_called_once_with(2, 3)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# the two values should now be cached; we should be able to vary
|
||||||
|
# the second argument and still get the cached result.
|
||||||
|
r = yield obj.fn(1, 4)
|
||||||
|
self.assertEqual(r, 'fish')
|
||||||
|
r = yield obj.fn(2, 5)
|
||||||
|
self.assertEqual(r, 'chips')
|
||||||
|
obj.mock.assert_not_called()
|
Loading…
Reference in a new issue