Minor @cachedList enhancements (#9975)

- use a tuple rather than a list for the iterable that is passed into the
  wrapped function, for performance

- test that we can pass an iterable and that keys are correctly deduped.
Richard van der Hoff 2021-05-14 11:12:36 +01:00 committed by GitHub
parent 52ed9655ed
commit 5090f26b63
6 changed files with 31 additions and 20 deletions
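
In practice, the change means a store method decorated with `@cachedList` now receives a deduplicated tuple of the cache misses, so callers no longer need their own defensive copy. A rough usage sketch follows; the store class here is hypothetical, while the decorators and signatures are the ones shown in the diff below (assuming Synapse's `synapse.util.caches.descriptors` module is importable):

```python
from typing import Dict, Iterable

from synapse.util.caches.descriptors import cached, cachedList


class ExampleErasureStore:
    """Hypothetical store illustrating the @cachedList pattern."""

    @cached()
    async def is_user_erased(self, user_id: str) -> bool:
        # single-item lookup; its cache is reused by the batch method below
        ...

    @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
    async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
        # After this commit, `user_ids` arrives here as a tuple of the
        # deduplicated keys that were not already cached, so the method no
        # longer needs its own `tuple(set(user_ids))` copy.
        ...
```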

changelog.d/9975.misc (new file)

@@ -0,0 +1 @@
+Minor enhancements to the `@cachedList` descriptor.

@@ -665,7 +665,7 @@ class DeviceWorkerStore(SQLBaseStore):
        cached_method_name="get_device_list_last_stream_id_for_remote",
        list_name="user_ids",
    )
-    async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
+    async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
        rows = await self.db_pool.simple_select_many_batch(
            table="device_lists_remote_extremeties",
            column="user_id",

@@ -473,7 +473,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
        num_args=1,
    )
    async def _get_bare_e2e_cross_signing_keys_bulk(
-        self, user_ids: List[str]
+        self, user_ids: Iterable[str]
    ) -> Dict[str, Dict[str, dict]]:
        """Returns the cross-signing keys for a set of users. The output of this
        function should be passed to _get_e2e_cross_signing_signatures_txn if

@@ -497,7 +497,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
    def _get_bare_e2e_cross_signing_keys_bulk_txn(
        self,
        txn: Connection,
-        user_ids: List[str],
+        user_ids: Iterable[str],
    ) -> Dict[str, Dict[str, dict]]:
        """Returns the cross-signing keys for a set of users. The output of this
        function should be passed to _get_e2e_cross_signing_signatures_txn if

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from typing import Dict, Iterable

from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList

@@ -37,21 +39,16 @@ class UserErasureWorkerStore(SQLBaseStore):
        return bool(result)

    @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
-    async def are_users_erased(self, user_ids):
+    async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
        """
        Checks which users in a list have requested erasure

        Args:
-            user_ids (iterable[str]): full user id to check
+            user_ids: full user ids to check

        Returns:
-            dict[str, bool]: for each user, whether the user has requested erasure.
+            for each user, whether the user has requested erasure.
        """
-        # this serves the dual purpose of (a) making sure we can do len and
-        # iterate it multiple times, and (b) avoiding duplicates.
-        user_ids = tuple(set(user_ids))
        rows = await self.db_pool.simple_select_many_batch(
            table="erased_users",
            column="user_id",

@@ -322,8 +322,8 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
class DeferredCacheListDescriptor(_CacheDescriptorBase):
    """Wraps an existing cache to support bulk fetching of keys.

-    Given a list of keys it looks in the cache to find any hits, then passes
-    the list of missing keys to the wrapped function.
+    Given an iterable of keys it looks in the cache to find any hits, then passes
+    the tuple of missing keys to the wrapped function.

    Once wrapped, the function returns a Deferred which resolves to the list
    of results.

@@ -437,7 +437,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                    return f

                args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = list(missing)
+                # copy the missing set before sending it to the callee, to guard against
+                # modification.
+                args_to_call[self.list_name] = tuple(missing)

                cached_defers.append(
                    defer.maybeDeferred(

@@ -522,14 +524,14 @@ def cachedList(
    Used to do batch lookups for an already created cache. A single argument
    is specified as a list that is iterated through to lookup keys in the
-    original cache. A new list consisting of the keys that weren't in the cache
-    get passed to the original function, the result of which is stored in the
+    original cache. A new tuple consisting of the (deduplicated) keys that weren't in
+    the cache gets passed to the original function, the result of which is stored in the
    cache.

    Args:
        cached_method_name: The name of the single-item lookup method.
            This is only used to find the cache to use.
-        list_name: The name of the argument that is the list to use to
+        list_name: The name of the argument that is the iterable to use to
            do batch lookups in the cache.
        num_args: Number of arguments to use as the key in the cache
            (including list_name). Defaults to all named parameters.
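
The docstring changes above describe the behaviour the descriptor now has: requested keys are deduplicated, cache hits are filtered out, and the remaining misses are handed to the wrapped function as a tuple (a copy, so the callee cannot mutate the descriptor's working set). A minimal standalone sketch of that key-handling logic, not the descriptor code itself:

```python
from typing import Any, Callable, Dict, Iterable, Tuple


def batch_lookup(
    keys: Iterable[str],
    cache: Dict[str, Any],
    fetch: Callable[[Tuple[str, ...]], Dict[str, Any]],
) -> Dict[str, Any]:
    """Toy model of @cachedList's key handling (assumed behaviour, simplified)."""
    wanted = set(keys)               # dedup; also makes single-use iterables safe
    missing = wanted - cache.keys()  # only the keys that are not already cached
    if missing:
        # pass a tuple copy so `fetch` cannot modify our `missing` set
        cache.update(fetch(tuple(missing)))
    return {k: cache[k] for k in wanted if k in cache}
```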

@@ -666,18 +666,20 @@ class CachedListDescriptorTestCase(unittest.TestCase):
        with LoggingContext("c1") as c1:
            obj = Cls()
            obj.mock.return_value = {10: "fish", 20: "chips"}

+            # start the lookup off
            d1 = obj.list_fn([10, 20], 2)
            self.assertEqual(current_context(), SENTINEL_CONTEXT)
            r = yield d1
            self.assertEqual(current_context(), c1)
-            obj.mock.assert_called_once_with([10, 20], 2)
+            obj.mock.assert_called_once_with((10, 20), 2)
            self.assertEqual(r, {10: "fish", 20: "chips"})
            obj.mock.reset_mock()

            # a call with different params should call the mock again
            obj.mock.return_value = {30: "peas"}
            r = yield obj.list_fn([20, 30], 2)
-            obj.mock.assert_called_once_with([30], 2)
+            obj.mock.assert_called_once_with((30,), 2)
            self.assertEqual(r, {20: "chips", 30: "peas"})
            obj.mock.reset_mock()

@@ -692,6 +694,15 @@ class CachedListDescriptorTestCase(unittest.TestCase):
            obj.mock.assert_not_called()
            self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"})

+            # we should also be able to use a (single-use) iterable, and should
+            # deduplicate the keys
+            obj.mock.reset_mock()
+            obj.mock.return_value = {40: "gravy"}
+            iterable = (x for x in [10, 40, 40])
+            r = yield obj.list_fn(iterable, 2)
+            obj.mock.assert_called_once_with((40,), 2)
+            self.assertEqual(r, {10: "fish", 40: "gravy"})
+
    @defer.inlineCallbacks
    def test_invalidate(self):
        """Make sure that invalidation callbacks are called."""

@@ -717,7 +728,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
            # cache miss
            obj.mock.return_value = {10: "fish", 20: "chips"}
            r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
-            obj.mock.assert_called_once_with([10, 20], 2)
+            obj.mock.assert_called_once_with((10, 20), 2)
            self.assertEqual(r1, {10: "fish", 20: "chips"})
            obj.mock.reset_mock()