mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-16 06:51:46 +01:00
Minor @cachedList
enhancements (#9975)
- use a tuple rather than a list for the iterable that is passed into the wrapped function, for performance - test that we can pass an iterable and that keys are correctly deduped.
This commit is contained in:
parent
52ed9655ed
commit
5090f26b63
6 changed files with 31 additions and 20 deletions
1
changelog.d/9975.misc
Normal file
1
changelog.d/9975.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Minor enhancements to the `@cachedList` descriptor.
|
|
@ -665,7 +665,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
cached_method_name="get_device_list_last_stream_id_for_remote",
|
cached_method_name="get_device_list_last_stream_id_for_remote",
|
||||||
list_name="user_ids",
|
list_name="user_ids",
|
||||||
)
|
)
|
||||||
async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
|
async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="device_lists_remote_extremeties",
|
table="device_lists_remote_extremeties",
|
||||||
column="user_id",
|
column="user_id",
|
||||||
|
|
|
@ -473,7 +473,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
||||||
num_args=1,
|
num_args=1,
|
||||||
)
|
)
|
||||||
async def _get_bare_e2e_cross_signing_keys_bulk(
|
async def _get_bare_e2e_cross_signing_keys_bulk(
|
||||||
self, user_ids: List[str]
|
self, user_ids: Iterable[str]
|
||||||
) -> Dict[str, Dict[str, dict]]:
|
) -> Dict[str, Dict[str, dict]]:
|
||||||
"""Returns the cross-signing keys for a set of users. The output of this
|
"""Returns the cross-signing keys for a set of users. The output of this
|
||||||
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
||||||
|
@ -497,7 +497,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
||||||
def _get_bare_e2e_cross_signing_keys_bulk_txn(
|
def _get_bare_e2e_cross_signing_keys_bulk_txn(
|
||||||
self,
|
self,
|
||||||
txn: Connection,
|
txn: Connection,
|
||||||
user_ids: List[str],
|
user_ids: Iterable[str],
|
||||||
) -> Dict[str, Dict[str, dict]]:
|
) -> Dict[str, Dict[str, dict]]:
|
||||||
"""Returns the cross-signing keys for a set of users. The output of this
|
"""Returns the cross-signing keys for a set of users. The output of this
|
||||||
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Dict, Iterable
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
|
||||||
|
@ -37,21 +39,16 @@ class UserErasureWorkerStore(SQLBaseStore):
|
||||||
return bool(result)
|
return bool(result)
|
||||||
|
|
||||||
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
|
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
|
||||||
async def are_users_erased(self, user_ids):
|
async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
|
||||||
"""
|
"""
|
||||||
Checks which users in a list have requested erasure
|
Checks which users in a list have requested erasure
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_ids (iterable[str]): full user id to check
|
user_ids: full user ids to check
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, bool]:
|
|
||||||
for each user, whether the user has requested erasure.
|
for each user, whether the user has requested erasure.
|
||||||
"""
|
"""
|
||||||
# this serves the dual purpose of (a) making sure we can do len and
|
|
||||||
# iterate it multiple times, and (b) avoiding duplicates.
|
|
||||||
user_ids = tuple(set(user_ids))
|
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="erased_users",
|
table="erased_users",
|
||||||
column="user_id",
|
column="user_id",
|
||||||
|
|
|
@ -322,8 +322,8 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||||
class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||||
"""Wraps an existing cache to support bulk fetching of keys.
|
"""Wraps an existing cache to support bulk fetching of keys.
|
||||||
|
|
||||||
Given a list of keys it looks in the cache to find any hits, then passes
|
Given an iterable of keys it looks in the cache to find any hits, then passes
|
||||||
the list of missing keys to the wrapped function.
|
the tuple of missing keys to the wrapped function.
|
||||||
|
|
||||||
Once wrapped, the function returns a Deferred which resolves to the list
|
Once wrapped, the function returns a Deferred which resolves to the list
|
||||||
of results.
|
of results.
|
||||||
|
@ -437,7 +437,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||||
return f
|
return f
|
||||||
|
|
||||||
args_to_call = dict(arg_dict)
|
args_to_call = dict(arg_dict)
|
||||||
args_to_call[self.list_name] = list(missing)
|
# copy the missing set before sending it to the callee, to guard against
|
||||||
|
# modification.
|
||||||
|
args_to_call[self.list_name] = tuple(missing)
|
||||||
|
|
||||||
cached_defers.append(
|
cached_defers.append(
|
||||||
defer.maybeDeferred(
|
defer.maybeDeferred(
|
||||||
|
@ -522,14 +524,14 @@ def cachedList(
|
||||||
|
|
||||||
Used to do batch lookups for an already created cache. A single argument
|
Used to do batch lookups for an already created cache. A single argument
|
||||||
is specified as a list that is iterated through to lookup keys in the
|
is specified as a list that is iterated through to lookup keys in the
|
||||||
original cache. A new list consisting of the keys that weren't in the cache
|
original cache. A new tuple consisting of the (deduplicated) keys that weren't in
|
||||||
get passed to the original function, the result of which is stored in the
|
the cache gets passed to the original function, the result of which is stored in the
|
||||||
cache.
|
cache.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cached_method_name: The name of the single-item lookup method.
|
cached_method_name: The name of the single-item lookup method.
|
||||||
This is only used to find the cache to use.
|
This is only used to find the cache to use.
|
||||||
list_name: The name of the argument that is the list to use to
|
list_name: The name of the argument that is the iterable to use to
|
||||||
do batch lookups in the cache.
|
do batch lookups in the cache.
|
||||||
num_args: Number of arguments to use as the key in the cache
|
num_args: Number of arguments to use as the key in the cache
|
||||||
(including list_name). Defaults to all named parameters.
|
(including list_name). Defaults to all named parameters.
|
||||||
|
|
|
@ -666,18 +666,20 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
with LoggingContext("c1") as c1:
|
with LoggingContext("c1") as c1:
|
||||||
obj = Cls()
|
obj = Cls()
|
||||||
obj.mock.return_value = {10: "fish", 20: "chips"}
|
obj.mock.return_value = {10: "fish", 20: "chips"}
|
||||||
|
|
||||||
|
# start the lookup off
|
||||||
d1 = obj.list_fn([10, 20], 2)
|
d1 = obj.list_fn([10, 20], 2)
|
||||||
self.assertEqual(current_context(), SENTINEL_CONTEXT)
|
self.assertEqual(current_context(), SENTINEL_CONTEXT)
|
||||||
r = yield d1
|
r = yield d1
|
||||||
self.assertEqual(current_context(), c1)
|
self.assertEqual(current_context(), c1)
|
||||||
obj.mock.assert_called_once_with([10, 20], 2)
|
obj.mock.assert_called_once_with((10, 20), 2)
|
||||||
self.assertEqual(r, {10: "fish", 20: "chips"})
|
self.assertEqual(r, {10: "fish", 20: "chips"})
|
||||||
obj.mock.reset_mock()
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
# a call with different params should call the mock again
|
# a call with different params should call the mock again
|
||||||
obj.mock.return_value = {30: "peas"}
|
obj.mock.return_value = {30: "peas"}
|
||||||
r = yield obj.list_fn([20, 30], 2)
|
r = yield obj.list_fn([20, 30], 2)
|
||||||
obj.mock.assert_called_once_with([30], 2)
|
obj.mock.assert_called_once_with((30,), 2)
|
||||||
self.assertEqual(r, {20: "chips", 30: "peas"})
|
self.assertEqual(r, {20: "chips", 30: "peas"})
|
||||||
obj.mock.reset_mock()
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
@ -692,6 +694,15 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
obj.mock.assert_not_called()
|
obj.mock.assert_not_called()
|
||||||
self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"})
|
self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"})
|
||||||
|
|
||||||
|
# we should also be able to use a (single-use) iterable, and should
|
||||||
|
# deduplicate the keys
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
obj.mock.return_value = {40: "gravy"}
|
||||||
|
iterable = (x for x in [10, 40, 40])
|
||||||
|
r = yield obj.list_fn(iterable, 2)
|
||||||
|
obj.mock.assert_called_once_with((40,), 2)
|
||||||
|
self.assertEqual(r, {10: "fish", 40: "gravy"})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_invalidate(self):
|
def test_invalidate(self):
|
||||||
"""Make sure that invalidation callbacks are called."""
|
"""Make sure that invalidation callbacks are called."""
|
||||||
|
@ -717,7 +728,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
# cache miss
|
# cache miss
|
||||||
obj.mock.return_value = {10: "fish", 20: "chips"}
|
obj.mock.return_value = {10: "fish", 20: "chips"}
|
||||||
r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
|
r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
|
||||||
obj.mock.assert_called_once_with([10, 20], 2)
|
obj.mock.assert_called_once_with((10, 20), 2)
|
||||||
self.assertEqual(r1, {10: "fish", 20: "chips"})
|
self.assertEqual(r1, {10: "fish", 20: "chips"})
|
||||||
obj.mock.reset_mock()
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue