mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-21 03:03:51 +01:00
Back out in-flight state caching changes. (#12126)
This commit is contained in:
parent
8e56a1b73c
commit
c7b2f1ccdc
7 changed files with 26 additions and 676 deletions
|
@ -1 +0,0 @@
|
||||||
Deduplicate in-flight requests in `_get_state_for_groups`.
|
|
|
@ -1 +0,0 @@
|
||||||
Deduplicate in-flight requests in `_get_state_for_groups`.
|
|
|
@ -1 +0,0 @@
|
||||||
Deduplicate in-flight requests in `_get_state_for_groups`.
|
|
|
@ -1 +0,0 @@
|
||||||
Deduplicate in-flight requests in `_get_state_for_groups`.
|
|
1
changelog.d/12126.removal
Normal file
1
changelog.d/12126.removal
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Back out in-flight state caching changes.
|
|
@ -13,24 +13,11 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
|
||||||
TYPE_CHECKING,
|
|
||||||
Collection,
|
|
||||||
Dict,
|
|
||||||
Iterable,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Set,
|
|
||||||
Tuple,
|
|
||||||
)
|
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from sortedcontainers import SortedDict
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.database import (
|
from synapse.storage.database import (
|
||||||
DatabasePool,
|
DatabasePool,
|
||||||
|
@ -42,12 +29,6 @@ from synapse.storage.state import StateFilter
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.storage.util.sequence import build_sequence_generator
|
from synapse.storage.util.sequence import build_sequence_generator
|
||||||
from synapse.types import MutableStateMap, StateKey, StateMap
|
from synapse.types import MutableStateMap, StateKey, StateMap
|
||||||
from synapse.util import unwrapFirstError
|
|
||||||
from synapse.util.async_helpers import (
|
|
||||||
AbstractObservableDeferred,
|
|
||||||
ObservableDeferred,
|
|
||||||
yieldable_gather_results,
|
|
||||||
)
|
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
from synapse.util.caches.dictionary_cache import DictionaryCache
|
from synapse.util.caches.dictionary_cache import DictionaryCache
|
||||||
|
|
||||||
|
@ -56,8 +37,8 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
MAX_STATE_DELTA_HOPS = 100
|
MAX_STATE_DELTA_HOPS = 100
|
||||||
MAX_INFLIGHT_REQUESTS_PER_GROUP = 5
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
@ -73,24 +54,6 @@ class _GetStateGroupDelta:
|
||||||
return len(self.delta_ids) if self.delta_ids else 0
|
return len(self.delta_ids) if self.delta_ids else 0
|
||||||
|
|
||||||
|
|
||||||
def state_filter_rough_priority_comparator(
|
|
||||||
state_filter: StateFilter,
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
"""
|
|
||||||
Returns a comparable value that roughly indicates the relative size of this
|
|
||||||
state filter compared to others.
|
|
||||||
'Larger' state filters should sort first when using ascending order, so
|
|
||||||
this is essentially the opposite of 'size'.
|
|
||||||
It should be treated as a rough guide only and should not be interpreted to
|
|
||||||
have any particular meaning. The representation may also change
|
|
||||||
|
|
||||||
The current implementation returns a tuple of the form:
|
|
||||||
* -1 for include_others, 0 otherwise
|
|
||||||
* -(number of entries in state_filter.types)
|
|
||||||
"""
|
|
||||||
return -int(state_filter.include_others), -len(state_filter.types)
|
|
||||||
|
|
||||||
|
|
||||||
class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
"""A data store for fetching/storing state groups."""
|
"""A data store for fetching/storing state groups."""
|
||||||
|
|
||||||
|
@ -143,12 +106,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
500000,
|
500000,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Current ongoing get_state_for_groups in-flight requests
|
|
||||||
# {group ID -> {StateFilter -> ObservableDeferred}}
|
|
||||||
self._state_group_inflight_requests: Dict[
|
|
||||||
int, SortedDict[StateFilter, AbstractObservableDeferred[StateMap[str]]]
|
|
||||||
] = {}
|
|
||||||
|
|
||||||
def get_max_state_group_txn(txn: Cursor) -> int:
|
def get_max_state_group_txn(txn: Cursor) -> int:
|
||||||
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
|
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
|
||||||
return txn.fetchone()[0] # type: ignore
|
return txn.fetchone()[0] # type: ignore
|
||||||
|
@ -200,7 +157,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _get_state_groups_from_groups(
|
async def _get_state_groups_from_groups(
|
||||||
self, groups: Sequence[int], state_filter: StateFilter
|
self, groups: List[int], state_filter: StateFilter
|
||||||
) -> Dict[int, StateMap[str]]:
|
) -> Dict[int, StateMap[str]]:
|
||||||
"""Returns the state groups for a given set of groups from the
|
"""Returns the state groups for a given set of groups from the
|
||||||
database, filtering on types of state events.
|
database, filtering on types of state events.
|
||||||
|
@ -271,170 +228,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
|
|
||||||
return state_filter.filter_state(state_dict_ids), not missing_types
|
return state_filter.filter_state(state_dict_ids), not missing_types
|
||||||
|
|
||||||
def _get_state_for_group_gather_inflight_requests(
|
|
||||||
self, group: int, state_filter_left_over: StateFilter
|
|
||||||
) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]:
|
|
||||||
"""
|
|
||||||
Attempts to gather in-flight requests and re-use them to retrieve state
|
|
||||||
for the given state group, filtered with the given state filter.
|
|
||||||
|
|
||||||
If there are more than MAX_INFLIGHT_REQUESTS_PER_GROUP in-flight requests,
|
|
||||||
and there *still* isn't enough information to complete the request by solely
|
|
||||||
reusing others, a full state filter will be requested to ensure that subsequent
|
|
||||||
requests can reuse this request.
|
|
||||||
|
|
||||||
Used as part of _get_state_for_group_using_inflight_cache.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of two values:
|
|
||||||
A sequence of ObservableDeferreds to observe
|
|
||||||
A StateFilter representing what else needs to be requested to fulfill the request
|
|
||||||
"""
|
|
||||||
|
|
||||||
inflight_requests = self._state_group_inflight_requests.get(group)
|
|
||||||
if inflight_requests is None:
|
|
||||||
# no requests for this group, need to retrieve it all ourselves
|
|
||||||
return (), state_filter_left_over
|
|
||||||
|
|
||||||
# The list of ongoing requests which will help narrow the current request.
|
|
||||||
reusable_requests = []
|
|
||||||
|
|
||||||
# Iterate over existing requests in roughly biggest-first order.
|
|
||||||
for request_state_filter in inflight_requests:
|
|
||||||
request_deferred = inflight_requests[request_state_filter]
|
|
||||||
new_state_filter_left_over = state_filter_left_over.approx_difference(
|
|
||||||
request_state_filter
|
|
||||||
)
|
|
||||||
if new_state_filter_left_over == state_filter_left_over:
|
|
||||||
# Reusing this request would not gain us anything, so don't bother.
|
|
||||||
continue
|
|
||||||
|
|
||||||
reusable_requests.append(request_deferred)
|
|
||||||
state_filter_left_over = new_state_filter_left_over
|
|
||||||
if state_filter_left_over == StateFilter.none():
|
|
||||||
# we have managed to collect enough of the in-flight requests
|
|
||||||
# to cover our StateFilter and give us the state we need.
|
|
||||||
break
|
|
||||||
|
|
||||||
if (
|
|
||||||
state_filter_left_over != StateFilter.none()
|
|
||||||
and len(inflight_requests) >= MAX_INFLIGHT_REQUESTS_PER_GROUP
|
|
||||||
):
|
|
||||||
# There are too many requests for this group.
|
|
||||||
# To prevent even more from building up, we request the whole
|
|
||||||
# state filter to guarantee that we can be reused by any subsequent
|
|
||||||
# requests for this state group.
|
|
||||||
return (), StateFilter.all()
|
|
||||||
|
|
||||||
return reusable_requests, state_filter_left_over
|
|
||||||
|
|
||||||
async def _get_state_for_group_fire_request(
|
|
||||||
self, group: int, state_filter: StateFilter
|
|
||||||
) -> StateMap[str]:
|
|
||||||
"""
|
|
||||||
Fires off a request to get the state at a state group,
|
|
||||||
potentially filtering by type and/or state key.
|
|
||||||
|
|
||||||
This request will be tracked in the in-flight request cache and automatically
|
|
||||||
removed when it is finished.
|
|
||||||
|
|
||||||
Used as part of _get_state_for_group_using_inflight_cache.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
group: ID of the state group for which we want to get state
|
|
||||||
state_filter: the state filter used to fetch state from the database
|
|
||||||
"""
|
|
||||||
cache_sequence_nm = self._state_group_cache.sequence
|
|
||||||
cache_sequence_m = self._state_group_members_cache.sequence
|
|
||||||
|
|
||||||
# Help the cache hit ratio by expanding the filter a bit
|
|
||||||
db_state_filter = state_filter.return_expanded()
|
|
||||||
|
|
||||||
async def _the_request() -> StateMap[str]:
|
|
||||||
group_to_state_dict = await self._get_state_groups_from_groups(
|
|
||||||
(group,), state_filter=db_state_filter
|
|
||||||
)
|
|
||||||
|
|
||||||
# Now let's update the caches
|
|
||||||
self._insert_into_cache(
|
|
||||||
group_to_state_dict,
|
|
||||||
db_state_filter,
|
|
||||||
cache_seq_num_members=cache_sequence_m,
|
|
||||||
cache_seq_num_non_members=cache_sequence_nm,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Remove ourselves from the in-flight cache
|
|
||||||
group_request_dict = self._state_group_inflight_requests[group]
|
|
||||||
del group_request_dict[db_state_filter]
|
|
||||||
if not group_request_dict:
|
|
||||||
# If there are no more requests in-flight for this group,
|
|
||||||
# clean up the cache by removing the empty dictionary
|
|
||||||
del self._state_group_inflight_requests[group]
|
|
||||||
|
|
||||||
return group_to_state_dict[group]
|
|
||||||
|
|
||||||
# We don't immediately await the result, so must use run_in_background
|
|
||||||
# But we DO await the result before the current log context (request)
|
|
||||||
# finishes, so don't need to run it as a background process.
|
|
||||||
request_deferred = run_in_background(_the_request)
|
|
||||||
observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True)
|
|
||||||
|
|
||||||
# Insert the ObservableDeferred into the cache
|
|
||||||
group_request_dict = self._state_group_inflight_requests.setdefault(
|
|
||||||
group, SortedDict(state_filter_rough_priority_comparator)
|
|
||||||
)
|
|
||||||
group_request_dict[db_state_filter] = observable_deferred
|
|
||||||
|
|
||||||
return await make_deferred_yieldable(observable_deferred.observe())
|
|
||||||
|
|
||||||
async def _get_state_for_group_using_inflight_cache(
|
|
||||||
self, group: int, state_filter: StateFilter
|
|
||||||
) -> MutableStateMap[str]:
|
|
||||||
"""
|
|
||||||
Gets the state at a state group, potentially filtering by type and/or
|
|
||||||
state key.
|
|
||||||
|
|
||||||
1. Calls _get_state_for_group_gather_inflight_requests to gather any
|
|
||||||
ongoing requests which might overlap with the current request.
|
|
||||||
2. Fires a new request, using _get_state_for_group_fire_request,
|
|
||||||
for any state which cannot be gathered from ongoing requests.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
group: ID of the state group for which we want to get state
|
|
||||||
state_filter: the state filter used to fetch state from the database
|
|
||||||
Returns:
|
|
||||||
state map
|
|
||||||
"""
|
|
||||||
|
|
||||||
# first, figure out whether we can re-use any in-flight requests
|
|
||||||
# (and if so, what would be left over)
|
|
||||||
(
|
|
||||||
reusable_requests,
|
|
||||||
state_filter_left_over,
|
|
||||||
) = self._get_state_for_group_gather_inflight_requests(group, state_filter)
|
|
||||||
|
|
||||||
if state_filter_left_over != StateFilter.none():
|
|
||||||
# Fetch remaining state
|
|
||||||
remaining = await self._get_state_for_group_fire_request(
|
|
||||||
group, state_filter_left_over
|
|
||||||
)
|
|
||||||
assembled_state: MutableStateMap[str] = dict(remaining)
|
|
||||||
else:
|
|
||||||
assembled_state = {}
|
|
||||||
|
|
||||||
gathered = await make_deferred_yieldable(
|
|
||||||
defer.gatherResults(
|
|
||||||
(r.observe() for r in reusable_requests), consumeErrors=True
|
|
||||||
)
|
|
||||||
).addErrback(unwrapFirstError)
|
|
||||||
|
|
||||||
# assemble our result.
|
|
||||||
for result_piece in gathered:
|
|
||||||
assembled_state.update(result_piece)
|
|
||||||
|
|
||||||
# Filter out any state that may be more than what we asked for.
|
|
||||||
return state_filter.filter_state(assembled_state)
|
|
||||||
|
|
||||||
async def _get_state_for_groups(
|
async def _get_state_for_groups(
|
||||||
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
||||||
) -> Dict[int, MutableStateMap[str]]:
|
) -> Dict[int, MutableStateMap[str]]:
|
||||||
|
@ -476,17 +269,31 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
if not incomplete_groups:
|
if not incomplete_groups:
|
||||||
return state
|
return state
|
||||||
|
|
||||||
async def get_from_cache(group: int, state_filter: StateFilter) -> None:
|
cache_sequence_nm = self._state_group_cache.sequence
|
||||||
state[group] = await self._get_state_for_group_using_inflight_cache(
|
cache_sequence_m = self._state_group_members_cache.sequence
|
||||||
group, state_filter
|
|
||||||
)
|
|
||||||
|
|
||||||
await yieldable_gather_results(
|
# Help the cache hit ratio by expanding the filter a bit
|
||||||
get_from_cache,
|
db_state_filter = state_filter.return_expanded()
|
||||||
incomplete_groups,
|
|
||||||
state_filter,
|
group_to_state_dict = await self._get_state_groups_from_groups(
|
||||||
|
list(incomplete_groups), state_filter=db_state_filter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Now lets update the caches
|
||||||
|
self._insert_into_cache(
|
||||||
|
group_to_state_dict,
|
||||||
|
db_state_filter,
|
||||||
|
cache_seq_num_members=cache_sequence_m,
|
||||||
|
cache_seq_num_non_members=cache_sequence_nm,
|
||||||
|
)
|
||||||
|
|
||||||
|
# And finally update the result dict, by filtering out any extra
|
||||||
|
# stuff we pulled out of the database.
|
||||||
|
for group, group_state_dict in group_to_state_dict.items():
|
||||||
|
# We just replace any existing entries, as we will have loaded
|
||||||
|
# everything we need from the database anyway.
|
||||||
|
state[group] = state_filter.filter_state(group_state_dict)
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def _get_state_for_groups_using_cache(
|
def _get_state_for_groups_using_cache(
|
||||||
|
|
|
@ -1,454 +0,0 @@
|
||||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import typing
|
|
||||||
from typing import Dict, List, Sequence, Tuple
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from parameterized import parameterized
|
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred, ensureDeferred
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
|
||||||
from synapse.storage.databases.state.store import (
|
|
||||||
MAX_INFLIGHT_REQUESTS_PER_GROUP,
|
|
||||||
state_filter_rough_priority_comparator,
|
|
||||||
)
|
|
||||||
from synapse.storage.state import StateFilter
|
|
||||||
from synapse.types import StateMap
|
|
||||||
from synapse.util import Clock
|
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from synapse.server import HomeServer
|
|
||||||
|
|
||||||
# StateFilter for ALL non-m.room.member state events
|
|
||||||
ALL_NON_MEMBERS_STATE_FILTER = StateFilter.freeze(
|
|
||||||
types={EventTypes.Member: set()},
|
|
||||||
include_others=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
FAKE_STATE = {
|
|
||||||
(EventTypes.Member, "@alice:test"): "join",
|
|
||||||
(EventTypes.Member, "@bob:test"): "leave",
|
|
||||||
(EventTypes.Member, "@charlie:test"): "invite",
|
|
||||||
("test.type", "a"): "AAA",
|
|
||||||
("test.type", "b"): "BBB",
|
|
||||||
("other.event.type", "state.key"): "123",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class StateGroupInflightCachingTestCase(HomeserverTestCase):
|
|
||||||
def prepare(
|
|
||||||
self, reactor: MemoryReactor, clock: Clock, homeserver: "HomeServer"
|
|
||||||
) -> None:
|
|
||||||
self.state_storage = homeserver.get_storage().state
|
|
||||||
self.state_datastore = homeserver.get_datastores().state
|
|
||||||
# Patch out the `_get_state_groups_from_groups`.
|
|
||||||
# This is useful because it lets us pretend we have a slow database.
|
|
||||||
get_state_groups_patch = patch.object(
|
|
||||||
self.state_datastore,
|
|
||||||
"_get_state_groups_from_groups",
|
|
||||||
self._fake_get_state_groups_from_groups,
|
|
||||||
)
|
|
||||||
get_state_groups_patch.start()
|
|
||||||
|
|
||||||
self.addCleanup(get_state_groups_patch.stop)
|
|
||||||
self.get_state_group_calls: List[
|
|
||||||
Tuple[Tuple[int, ...], StateFilter, Deferred[Dict[int, StateMap[str]]]]
|
|
||||||
] = []
|
|
||||||
|
|
||||||
def _fake_get_state_groups_from_groups(
|
|
||||||
self, groups: Sequence[int], state_filter: StateFilter
|
|
||||||
) -> "Deferred[Dict[int, StateMap[str]]]":
|
|
||||||
d: Deferred[Dict[int, StateMap[str]]] = Deferred()
|
|
||||||
self.get_state_group_calls.append((tuple(groups), state_filter, d))
|
|
||||||
return d
|
|
||||||
|
|
||||||
def _complete_request_fake(
|
|
||||||
self,
|
|
||||||
groups: Tuple[int, ...],
|
|
||||||
state_filter: StateFilter,
|
|
||||||
d: "Deferred[Dict[int, StateMap[str]]]",
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Assemble a fake database response and complete the database request.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Return a filtered copy of the fake state
|
|
||||||
d.callback({group: state_filter.filter_state(FAKE_STATE) for group in groups})
|
|
||||||
|
|
||||||
def test_duplicate_requests_deduplicated(self) -> None:
|
|
||||||
"""
|
|
||||||
Tests that duplicate requests for state are deduplicated.
|
|
||||||
|
|
||||||
This test:
|
|
||||||
- requests some state (state group 42, 'all' state filter)
|
|
||||||
- requests it again, before the first request finishes
|
|
||||||
- checks to see that only one database query was made
|
|
||||||
- completes the database query
|
|
||||||
- checks that both requests see the same retrieved state
|
|
||||||
"""
|
|
||||||
req1 = ensureDeferred(
|
|
||||||
self.state_datastore._get_state_for_group_using_inflight_cache(
|
|
||||||
42, StateFilter.all()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.pump(by=0.1)
|
|
||||||
|
|
||||||
# This should have gone to the database
|
|
||||||
self.assertEqual(len(self.get_state_group_calls), 1)
|
|
||||||
self.assertFalse(req1.called)
|
|
||||||
|
|
||||||
req2 = ensureDeferred(
|
|
||||||
self.state_datastore._get_state_for_group_using_inflight_cache(
|
|
||||||
42, StateFilter.all()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.pump(by=0.1)
|
|
||||||
|
|
||||||
# No more calls should have gone to the database
|
|
||||||
self.assertEqual(len(self.get_state_group_calls), 1)
|
|
||||||
self.assertFalse(req1.called)
|
|
||||||
self.assertFalse(req2.called)
|
|
||||||
|
|
||||||
groups, sf, d = self.get_state_group_calls[0]
|
|
||||||
self.assertEqual(groups, (42,))
|
|
||||||
self.assertEqual(sf, StateFilter.all())
|
|
||||||
|
|
||||||
# Now we can complete the request
|
|
||||||
self._complete_request_fake(groups, sf, d)
|
|
||||||
|
|
||||||
self.assertEqual(self.get_success(req1), FAKE_STATE)
|
|
||||||
self.assertEqual(self.get_success(req2), FAKE_STATE)
|
|
||||||
|
|
||||||
def test_smaller_request_deduplicated(self) -> None:
|
|
||||||
"""
|
|
||||||
Tests that duplicate requests for state are deduplicated.
|
|
||||||
|
|
||||||
This test:
|
|
||||||
- requests some state (state group 42, 'all' state filter)
|
|
||||||
- requests a subset of that state, before the first request finishes
|
|
||||||
- checks to see that only one database query was made
|
|
||||||
- completes the database query
|
|
||||||
- checks that both requests see the correct retrieved state
|
|
||||||
"""
|
|
||||||
req1 = ensureDeferred(
|
|
||||||
self.state_datastore._get_state_for_group_using_inflight_cache(
|
|
||||||
42, StateFilter.from_types((("test.type", None),))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.pump(by=0.1)
|
|
||||||
|
|
||||||
# This should have gone to the database
|
|
||||||
self.assertEqual(len(self.get_state_group_calls), 1)
|
|
||||||
self.assertFalse(req1.called)
|
|
||||||
|
|
||||||
req2 = ensureDeferred(
|
|
||||||
self.state_datastore._get_state_for_group_using_inflight_cache(
|
|
||||||
42, StateFilter.from_types((("test.type", "b"),))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.pump(by=0.1)
|
|
||||||
|
|
||||||
# No more calls should have gone to the database, because the second
|
|
||||||
# request was already in the in-flight cache!
|
|
||||||
self.assertEqual(len(self.get_state_group_calls), 1)
|
|
||||||
self.assertFalse(req1.called)
|
|
||||||
self.assertFalse(req2.called)
|
|
||||||
|
|
||||||
groups, sf, d = self.get_state_group_calls[0]
|
|
||||||
self.assertEqual(groups, (42,))
|
|
||||||
# The state filter is expanded internally for increased cache hit rate,
|
|
||||||
# so we the database sees a wider state filter than requested.
|
|
||||||
self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER)
|
|
||||||
|
|
||||||
# Now we can complete the request
|
|
||||||
self._complete_request_fake(groups, sf, d)
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
self.get_success(req1),
|
|
||||||
{("test.type", "a"): "AAA", ("test.type", "b"): "BBB"},
|
|
||||||
)
|
|
||||||
self.assertEqual(self.get_success(req2), {("test.type", "b"): "BBB"})
|
|
||||||
|
|
||||||
def test_partially_overlapping_request_deduplicated(self) -> None:
|
|
||||||
"""
|
|
||||||
Tests that partially-overlapping requests are partially deduplicated.
|
|
||||||
|
|
||||||
This test:
|
|
||||||
- requests a single type of wildcard state
|
|
||||||
(This is internally expanded to be all non-member state)
|
|
||||||
- requests the entire state in parallel
|
|
||||||
- checks to see that two database queries were made, but that the second
|
|
||||||
one is only for member state.
|
|
||||||
- completes the database queries
|
|
||||||
- checks that both requests have the correct result.
|
|
||||||
"""
|
|
||||||
|
|
||||||
req1 = ensureDeferred(
|
|
||||||
self.state_datastore._get_state_for_group_using_inflight_cache(
|
|
||||||
42, StateFilter.from_types((("test.type", None),))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.pump(by=0.1)
|
|
||||||
|
|
||||||
# This should have gone to the database
|
|
||||||
self.assertEqual(len(self.get_state_group_calls), 1)
|
|
||||||
self.assertFalse(req1.called)
|
|
||||||
|
|
||||||
req2 = ensureDeferred(
|
|
||||||
self.state_datastore._get_state_for_group_using_inflight_cache(
|
|
||||||
42, StateFilter.all()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.pump(by=0.1)
|
|
||||||
|
|
||||||
# Because it only partially overlaps, this also went to the database
|
|
||||||
self.assertEqual(len(self.get_state_group_calls), 2)
|
|
||||||
self.assertFalse(req1.called)
|
|
||||||
self.assertFalse(req2.called)
|
|
||||||
|
|
||||||
# First request:
|
|
||||||
groups, sf, d = self.get_state_group_calls[0]
|
|
||||||
self.assertEqual(groups, (42,))
|
|
||||||
# The state filter is expanded internally for increased cache hit rate,
|
|
||||||
# so we the database sees a wider state filter than requested.
|
|
||||||
self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER)
|
|
||||||
self._complete_request_fake(groups, sf, d)
|
|
||||||
|
|
||||||
# Second request:
|
|
||||||
groups, sf, d = self.get_state_group_calls[1]
|
|
||||||
self.assertEqual(groups, (42,))
|
|
||||||
# The state filter is narrowed to only request membership state, because
|
|
||||||
# the remainder of the state is already being queried in the first request!
|
|
||||||
self.assertEqual(
|
|
||||||
sf, StateFilter.freeze({EventTypes.Member: None}, include_others=False)
|
|
||||||
)
|
|
||||||
self._complete_request_fake(groups, sf, d)
|
|
||||||
|
|
||||||
# Check the results are correct
|
|
||||||
self.assertEqual(
|
|
||||||
self.get_success(req1),
|
|
||||||
{("test.type", "a"): "AAA", ("test.type", "b"): "BBB"},
|
|
||||||
)
|
|
||||||
self.assertEqual(self.get_success(req2), FAKE_STATE)
|
|
||||||
|
|
||||||
def test_in_flight_requests_stop_being_in_flight(self) -> None:
|
|
||||||
"""
|
|
||||||
Tests that in-flight request deduplication doesn't somehow 'hold on'
|
|
||||||
to completed requests: once they're done, they're taken out of the
|
|
||||||
in-flight cache.
|
|
||||||
"""
|
|
||||||
req1 = ensureDeferred(
|
|
||||||
self.state_datastore._get_state_for_group_using_inflight_cache(
|
|
||||||
42, StateFilter.all()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.pump(by=0.1)
|
|
||||||
|
|
||||||
# This should have gone to the database
|
|
||||||
self.assertEqual(len(self.get_state_group_calls), 1)
|
|
||||||
self.assertFalse(req1.called)
|
|
||||||
|
|
||||||
# Complete the request right away.
|
|
||||||
self._complete_request_fake(*self.get_state_group_calls[0])
|
|
||||||
self.assertTrue(req1.called)
|
|
||||||
|
|
||||||
# Send off another request
|
|
||||||
req2 = ensureDeferred(
|
|
||||||
self.state_datastore._get_state_for_group_using_inflight_cache(
|
|
||||||
42, StateFilter.all()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.pump(by=0.1)
|
|
||||||
|
|
||||||
# It should have gone to the database again, because the previous request
|
|
||||||
# isn't in-flight and therefore isn't available for deduplication.
|
|
||||||
self.assertEqual(len(self.get_state_group_calls), 2)
|
|
||||||
self.assertFalse(req2.called)
|
|
||||||
|
|
||||||
# Complete the request right away.
|
|
||||||
self._complete_request_fake(*self.get_state_group_calls[1])
|
|
||||||
self.assertTrue(req2.called)
|
|
||||||
groups, sf, d = self.get_state_group_calls[0]
|
|
||||||
|
|
||||||
self.assertEqual(self.get_success(req1), FAKE_STATE)
|
|
||||||
self.assertEqual(self.get_success(req2), FAKE_STATE)
|
|
||||||
|
|
||||||
def test_inflight_requests_capped(self) -> None:
|
|
||||||
"""
|
|
||||||
Tests that the number of in-flight requests is capped to 5.
|
|
||||||
|
|
||||||
- requests several pieces of state separately
|
|
||||||
(5 to hit the limit, 1 to 'shunt out', another that comes after the
|
|
||||||
group has been 'shunted out')
|
|
||||||
- checks to see that the torrent of requests is shunted out by
|
|
||||||
rewriting one of the filters as the 'all' state filter
|
|
||||||
- requests after that one do not cause any additional queries
|
|
||||||
"""
|
|
||||||
# 5 at the time of writing.
|
|
||||||
CAP_COUNT = MAX_INFLIGHT_REQUESTS_PER_GROUP
|
|
||||||
|
|
||||||
reqs = []
|
|
||||||
|
|
||||||
# Request 7 different keys (1 to 7) of the `some.state` type.
|
|
||||||
for req_id in range(CAP_COUNT + 2):
|
|
||||||
reqs.append(
|
|
||||||
ensureDeferred(
|
|
||||||
self.state_datastore._get_state_for_group_using_inflight_cache(
|
|
||||||
42,
|
|
||||||
StateFilter.freeze(
|
|
||||||
{"some.state": {str(req_id + 1)}}, include_others=False
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.pump(by=0.1)
|
|
||||||
|
|
||||||
# There should only be 6 calls to the database, not 7.
|
|
||||||
self.assertEqual(len(self.get_state_group_calls), CAP_COUNT + 1)
|
|
||||||
|
|
||||||
# Assert that the first 5 are exact requests for the individual pieces
|
|
||||||
# wanted
|
|
||||||
for req_id in range(CAP_COUNT):
|
|
||||||
groups, sf, d = self.get_state_group_calls[req_id]
|
|
||||||
self.assertEqual(
|
|
||||||
sf,
|
|
||||||
StateFilter.freeze(
|
|
||||||
{"some.state": {str(req_id + 1)}}, include_others=False
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# The 6th request should be the 'all' state filter
|
|
||||||
groups, sf, d = self.get_state_group_calls[CAP_COUNT]
|
|
||||||
self.assertEqual(sf, StateFilter.all())
|
|
||||||
|
|
||||||
# Complete the queries and check which requests complete as a result
|
|
||||||
for req_id in range(CAP_COUNT):
|
|
||||||
# This request should not have been completed yet
|
|
||||||
self.assertFalse(reqs[req_id].called)
|
|
||||||
|
|
||||||
groups, sf, d = self.get_state_group_calls[req_id]
|
|
||||||
self._complete_request_fake(groups, sf, d)
|
|
||||||
|
|
||||||
# This should have only completed this one request
|
|
||||||
self.assertTrue(reqs[req_id].called)
|
|
||||||
|
|
||||||
# Now complete the final query; the last 2 requests should complete
|
|
||||||
# as a result
|
|
||||||
self.assertFalse(reqs[CAP_COUNT].called)
|
|
||||||
self.assertFalse(reqs[CAP_COUNT + 1].called)
|
|
||||||
groups, sf, d = self.get_state_group_calls[CAP_COUNT]
|
|
||||||
self._complete_request_fake(groups, sf, d)
|
|
||||||
self.assertTrue(reqs[CAP_COUNT].called)
|
|
||||||
self.assertTrue(reqs[CAP_COUNT + 1].called)
|
|
||||||
|
|
||||||
@parameterized.expand([(False,), (True,)])
|
|
||||||
def test_ordering_of_request_reuse(self, reverse: bool) -> None:
|
|
||||||
"""
|
|
||||||
Tests that 'larger' in-flight requests are ordered first.
|
|
||||||
|
|
||||||
This is mostly a design decision in order to prevent a request from
|
|
||||||
hanging on to multiple queries when it would have been sufficient to
|
|
||||||
hang on to only one bigger query.
|
|
||||||
|
|
||||||
The 'size' of a state filter is a rough heuristic.
|
|
||||||
|
|
||||||
- requests two pieces of state, one 'larger' than the other, but each
|
|
||||||
spawning a query
|
|
||||||
- requests a third piece of state
|
|
||||||
- completes the larger of the first two queries
|
|
||||||
- checks that the third request gets completed (and doesn't needlessly
|
|
||||||
wait for the other query)
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
reverse: whether to reverse the order of the initial requests, to ensure
|
|
||||||
that the effect doesn't depend on the order of request submission.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# We add in an extra state type to make sure that both requests spawn
|
|
||||||
# queries which are not optimised out.
|
|
||||||
state_filters = [
|
|
||||||
StateFilter.freeze(
|
|
||||||
{"state.type": {"A"}, "other.state.type": {"a"}}, include_others=False
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
"state.type": None,
|
|
||||||
"other.state.type": {"b"},
|
|
||||||
# The current rough size comparator uses the number of state types
|
|
||||||
# as an indicator of size.
|
|
||||||
# To influence it to make this state filter bigger than the previous one,
|
|
||||||
# we add another dummy state type.
|
|
||||||
"extra.state.type": {"c"},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
if reverse:
|
|
||||||
# For fairness, we perform one test run with the list reversed.
|
|
||||||
state_filters.reverse()
|
|
||||||
smallest_state_filter_idx = 1
|
|
||||||
biggest_state_filter_idx = 0
|
|
||||||
else:
|
|
||||||
smallest_state_filter_idx = 0
|
|
||||||
biggest_state_filter_idx = 1
|
|
||||||
|
|
||||||
# This assertion is for our own sanity more than anything else.
|
|
||||||
self.assertLess(
|
|
||||||
state_filter_rough_priority_comparator(
|
|
||||||
state_filters[biggest_state_filter_idx]
|
|
||||||
),
|
|
||||||
state_filter_rough_priority_comparator(
|
|
||||||
state_filters[smallest_state_filter_idx]
|
|
||||||
),
|
|
||||||
"Test invalid: bigger state filter is not actually bigger.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Spawn the initial two requests
|
|
||||||
for state_filter in state_filters:
|
|
||||||
ensureDeferred(
|
|
||||||
self.state_datastore._get_state_for_group_using_inflight_cache(
|
|
||||||
42,
|
|
||||||
state_filter,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Spawn a third request
|
|
||||||
req = ensureDeferred(
|
|
||||||
self.state_datastore._get_state_for_group_using_inflight_cache(
|
|
||||||
42,
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
"state.type": {"A"},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.pump(by=0.1)
|
|
||||||
|
|
||||||
self.assertFalse(req.called)
|
|
||||||
|
|
||||||
# Complete the largest request's query to make sure that the final request
|
|
||||||
# only waits for that one (and doesn't needlessly wait for both queries)
|
|
||||||
self._complete_request_fake(
|
|
||||||
*self.get_state_group_calls[biggest_state_filter_idx]
|
|
||||||
)
|
|
||||||
|
|
||||||
# That should have been sufficient to complete the third request
|
|
||||||
self.assertTrue(req.called)
|
|
Loading…
Reference in a new issue