Compare commits
16 commits
develop
...
rei/gsfg_1
Author | SHA1 | Date | |
---|---|---|---|
e38f7953ef | |||
db840d2ad5 | |||
5352f2109c | |||
8ea530a7e5 | |||
f78a082989 | |||
af85ac449d | |||
58ef32e272 | |||
cc76d9f100 | |||
a30042b16a | |||
857b2d2039 | |||
831a7a4592 | |||
7dcdab407c | |||
247f558c1c | |||
d998903d46 | |||
1fceefd65d | |||
363565e6ac |
1
changelog.d/10870.misc
Normal file
1
changelog.d/10870.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Deduplicate in-flight requests in `_get_state_for_groups`.
|
1
mypy.ini
1
mypy.ini
|
@ -82,6 +82,7 @@ files =
|
||||||
tests/rest/client/test_relations.py,
|
tests/rest/client/test_relations.py,
|
||||||
tests/rest/media/v1/test_filepath.py,
|
tests/rest/media/v1/test_filepath.py,
|
||||||
tests/rest/media/v1/test_oembed.py,
|
tests/rest/media/v1/test_oembed.py,
|
||||||
|
tests/storage/databases/test_state_store.py,
|
||||||
tests/storage/test_state.py,
|
tests/storage/test_state.py,
|
||||||
tests/storage/test_user_directory.py,
|
tests/storage/test_user_directory.py,
|
||||||
tests/util/test_itertools.py,
|
tests/util/test_itertools.py,
|
||||||
|
|
|
@ -13,11 +13,23 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.database import (
|
from synapse.storage.database import (
|
||||||
DatabasePool,
|
DatabasePool,
|
||||||
|
@ -29,6 +41,8 @@ from synapse.storage.state import StateFilter
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.storage.util.sequence import build_sequence_generator
|
from synapse.storage.util.sequence import build_sequence_generator
|
||||||
from synapse.types import MutableStateMap, StateKey, StateMap
|
from synapse.types import MutableStateMap, StateKey, StateMap
|
||||||
|
from synapse.util import unwrapFirstError
|
||||||
|
from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
from synapse.util.caches.dictionary_cache import DictionaryCache
|
from synapse.util.caches.dictionary_cache import DictionaryCache
|
||||||
|
|
||||||
|
@ -37,7 +51,6 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
MAX_STATE_DELTA_HOPS = 100
|
MAX_STATE_DELTA_HOPS = 100
|
||||||
|
|
||||||
|
|
||||||
|
@ -106,6 +119,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
500000,
|
500000,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Current ongoing get_state_for_groups in-flight requests
|
||||||
|
# {group ID -> {StateFilter -> ObservableDeferred}}
|
||||||
|
self._state_group_inflight_requests: Dict[
|
||||||
|
int, Dict[StateFilter, ObservableDeferred[StateMap[str]]]
|
||||||
|
] = {}
|
||||||
|
|
||||||
def get_max_state_group_txn(txn: Cursor) -> int:
|
def get_max_state_group_txn(txn: Cursor) -> int:
|
||||||
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
|
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
|
||||||
return txn.fetchone()[0] # type: ignore
|
return txn.fetchone()[0] # type: ignore
|
||||||
|
@ -157,7 +176,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _get_state_groups_from_groups(
|
async def _get_state_groups_from_groups(
|
||||||
self, groups: List[int], state_filter: StateFilter
|
self, groups: Sequence[int], state_filter: StateFilter
|
||||||
) -> Dict[int, StateMap[str]]:
|
) -> Dict[int, StateMap[str]]:
|
||||||
"""Returns the state groups for a given set of groups from the
|
"""Returns the state groups for a given set of groups from the
|
||||||
database, filtering on types of state events.
|
database, filtering on types of state events.
|
||||||
|
@ -228,6 +247,144 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
|
|
||||||
return state_filter.filter_state(state_dict_ids), not missing_types
|
return state_filter.filter_state(state_dict_ids), not missing_types
|
||||||
|
|
||||||
|
def _get_state_for_group_gather_inflight_requests(
|
||||||
|
self, group: int, state_filter: StateFilter
|
||||||
|
) -> Tuple[Sequence[ObservableDeferred[StateMap[str]]], StateFilter]:
|
||||||
|
"""
|
||||||
|
Attempts to gather in-flight requests and re-use them to retrieve state
|
||||||
|
for the given state group, filtered with the given state filter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple (
|
||||||
|
sequence of ObservableDeferreds to observe,
|
||||||
|
StateFilter representing what is left over (what else needs
|
||||||
|
to be requested to fulfil the request)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
inflight_requests = self._state_group_inflight_requests.get(group)
|
||||||
|
if inflight_requests is None:
|
||||||
|
# no requests for this group, need to retrieve it all ourselves
|
||||||
|
return (), state_filter
|
||||||
|
|
||||||
|
state_filter_left_over = state_filter
|
||||||
|
reusable_requests = []
|
||||||
|
for (
|
||||||
|
request_state_filter,
|
||||||
|
request_deferred,
|
||||||
|
) in inflight_requests.items():
|
||||||
|
new_state_filter_left_over = state_filter_left_over.approx_difference(
|
||||||
|
request_state_filter
|
||||||
|
)
|
||||||
|
if new_state_filter_left_over != state_filter_left_over:
|
||||||
|
# reusing this request narrows our StateFilter down a bit.
|
||||||
|
reusable_requests.append(request_deferred)
|
||||||
|
state_filter_left_over = new_state_filter_left_over
|
||||||
|
if state_filter_left_over == StateFilter.none():
|
||||||
|
# we have managed to collect enough of the in-flight requests
|
||||||
|
# to cover our StateFilter and give us the state we need.
|
||||||
|
break
|
||||||
|
|
||||||
|
return reusable_requests, state_filter_left_over
|
||||||
|
|
||||||
|
async def _get_state_for_group_fire_request(
|
||||||
|
self, group: int, state_filter: StateFilter
|
||||||
|
) -> StateMap[str]:
|
||||||
|
"""
|
||||||
|
Fires off a request to get the state at a state group,
|
||||||
|
potentially filtering by type and/or state key.
|
||||||
|
|
||||||
|
This request will be tracked in the in-flight request cache and automatically
|
||||||
|
removed when it is finished.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group: ID of the state group for which we want to get state
|
||||||
|
state_filter: the state filter used to fetch state from the database
|
||||||
|
"""
|
||||||
|
cache_sequence_nm = self._state_group_cache.sequence
|
||||||
|
cache_sequence_m = self._state_group_members_cache.sequence
|
||||||
|
|
||||||
|
# Help the cache hit ratio by expanding the filter a bit
|
||||||
|
db_state_filter = state_filter.return_expanded()
|
||||||
|
|
||||||
|
async def _the_request() -> StateMap[str]:
|
||||||
|
group_to_state_dict = await self._get_state_groups_from_groups(
|
||||||
|
(group,), state_filter=db_state_filter
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now let's update the caches
|
||||||
|
self._insert_into_cache(
|
||||||
|
group_to_state_dict,
|
||||||
|
db_state_filter,
|
||||||
|
cache_seq_num_members=cache_sequence_m,
|
||||||
|
cache_seq_num_non_members=cache_sequence_nm,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove ourselves from the in-flight cache
|
||||||
|
group_request_dict = self._state_group_inflight_requests[group]
|
||||||
|
del group_request_dict[db_state_filter]
|
||||||
|
if not group_request_dict:
|
||||||
|
# If there are no more requests in-flight for this group,
|
||||||
|
# clean up the cache by removing the empty dictionary
|
||||||
|
del self._state_group_inflight_requests[group]
|
||||||
|
|
||||||
|
return group_to_state_dict[group]
|
||||||
|
|
||||||
|
# We don't immediately await the result, so must use run_in_background
|
||||||
|
# But we DO await the result before the current log context (request)
|
||||||
|
# finishes, so don't need to run it as a background process.
|
||||||
|
request_deferred = run_in_background(_the_request())
|
||||||
|
observable_deferred = ObservableDeferred(request_deferred)
|
||||||
|
|
||||||
|
# Insert the ObservableDeferred into the cache
|
||||||
|
group_request_dict = self._state_group_inflight_requests.setdefault(group, {})
|
||||||
|
group_request_dict[db_state_filter] = observable_deferred
|
||||||
|
|
||||||
|
return await make_deferred_yieldable(observable_deferred.observe())
|
||||||
|
|
||||||
|
async def _get_state_for_group_using_inflight_cache(
|
||||||
|
self, group: int, state_filter: StateFilter
|
||||||
|
) -> MutableStateMap[str]:
|
||||||
|
"""
|
||||||
|
Gets the state at a state group, potentially filtering by type and/or
|
||||||
|
state key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group: ID of the state group for which we want to get state
|
||||||
|
state_filter: the state filter used to fetch state from the database
|
||||||
|
Returns:
|
||||||
|
state map
|
||||||
|
"""
|
||||||
|
|
||||||
|
# first, figure out whether we can re-use any in-flight requests
|
||||||
|
# (and if so, what would be left over)
|
||||||
|
(
|
||||||
|
reusable_requests,
|
||||||
|
state_filter_left_over,
|
||||||
|
) = self._get_state_for_group_gather_inflight_requests(group, state_filter)
|
||||||
|
|
||||||
|
if state_filter_left_over != StateFilter.none():
|
||||||
|
# Fetch remaining state
|
||||||
|
remaining = await self._get_state_for_group_fire_request(
|
||||||
|
group, state_filter_left_over
|
||||||
|
)
|
||||||
|
assembled_state: MutableStateMap[str] = dict(remaining)
|
||||||
|
else:
|
||||||
|
assembled_state = {}
|
||||||
|
|
||||||
|
gathered = await make_deferred_yieldable(
|
||||||
|
defer.gatherResults(
|
||||||
|
(r.observe() for r in reusable_requests), consumeErrors=True
|
||||||
|
)
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
|
# assemble our result.
|
||||||
|
for result_piece in gathered:
|
||||||
|
assembled_state.update(result_piece)
|
||||||
|
|
||||||
|
# Filter out any state that may be more than what we asked for.
|
||||||
|
return state_filter.filter_state(assembled_state)
|
||||||
|
|
||||||
async def _get_state_for_groups(
|
async def _get_state_for_groups(
|
||||||
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
||||||
) -> Dict[int, MutableStateMap[str]]:
|
) -> Dict[int, MutableStateMap[str]]:
|
||||||
|
@ -269,30 +426,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
if not incomplete_groups:
|
if not incomplete_groups:
|
||||||
return state
|
return state
|
||||||
|
|
||||||
cache_sequence_nm = self._state_group_cache.sequence
|
results_from_requests = await yieldable_gather_results(
|
||||||
cache_sequence_m = self._state_group_members_cache.sequence
|
self._get_state_for_group_using_inflight_cache,
|
||||||
|
incomplete_groups,
|
||||||
# Help the cache hit ratio by expanding the filter a bit
|
state_filter,
|
||||||
db_state_filter = state_filter.return_expanded()
|
|
||||||
|
|
||||||
group_to_state_dict = await self._get_state_groups_from_groups(
|
|
||||||
list(incomplete_groups), state_filter=db_state_filter
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now lets update the caches
|
for group, group_result in zip(incomplete_groups, results_from_requests):
|
||||||
self._insert_into_cache(
|
state[group] = group_result
|
||||||
group_to_state_dict,
|
|
||||||
db_state_filter,
|
|
||||||
cache_seq_num_members=cache_sequence_m,
|
|
||||||
cache_seq_num_non_members=cache_sequence_nm,
|
|
||||||
)
|
|
||||||
|
|
||||||
# And finally update the result dict, by filtering out any extra
|
|
||||||
# stuff we pulled out of the database.
|
|
||||||
for group, group_state_dict in group_to_state_dict.items():
|
|
||||||
# We just replace any existing entries, as we will have loaded
|
|
||||||
# everything we need from the database anyway.
|
|
||||||
state[group] = state_filter.filter_state(group_state_dict)
|
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
109
tests/storage/databases/test_state_store.py
Normal file
109
tests/storage/databases/test_state_store.py
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
from typing import Dict, List, Sequence, Tuple
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from twisted.internet.defer import Deferred, ensureDeferred
|
||||||
|
|
||||||
|
from synapse.storage.state import StateFilter
|
||||||
|
from synapse.types import MutableStateMap, StateMap
|
||||||
|
|
||||||
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class StateGroupInflightCachingTestCase(HomeserverTestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
super(StateGroupInflightCachingTestCase, self).setUp()
|
||||||
|
# Patch out the `_get_state_groups_from_groups`.
|
||||||
|
# This is useful because it lets us pretend we have a slow database.
|
||||||
|
gsgfg_patch = patch(
|
||||||
|
"synapse.storage.databases.state.store.StateGroupDataStore._get_state_groups_from_groups",
|
||||||
|
self._fake_get_state_groups_from_groups,
|
||||||
|
)
|
||||||
|
gsgfg_patch.start()
|
||||||
|
self.addCleanup(gsgfg_patch.stop)
|
||||||
|
self.gsgfg_calls: List[
|
||||||
|
Tuple[Tuple[int, ...], StateFilter, Deferred[Dict[int, StateMap[str]]]]
|
||||||
|
] = []
|
||||||
|
|
||||||
|
def prepare(self, reactor, clock, homeserver) -> None:
|
||||||
|
super(StateGroupInflightCachingTestCase, self).prepare(
|
||||||
|
reactor, clock, homeserver
|
||||||
|
)
|
||||||
|
self.state_storage = homeserver.get_storage().state
|
||||||
|
self.state_datastore = homeserver.get_datastores().state
|
||||||
|
|
||||||
|
def _fake_get_state_groups_from_groups(
|
||||||
|
self, groups: Sequence[int], state_filter: StateFilter
|
||||||
|
) -> "Deferred[Dict[int, StateMap[str]]]":
|
||||||
|
print("hi", groups, state_filter)
|
||||||
|
d: Deferred[Dict[int, StateMap[str]]] = Deferred()
|
||||||
|
self.gsgfg_calls.append((tuple(groups), state_filter, d))
|
||||||
|
return d
|
||||||
|
|
||||||
|
def _complete_request_fake(
|
||||||
|
self,
|
||||||
|
groups: Tuple[int, ...],
|
||||||
|
state_filter: StateFilter,
|
||||||
|
d: "Deferred[Dict[int, StateMap[str]]]",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Assemble a fake database response and complete the database request.
|
||||||
|
"""
|
||||||
|
|
||||||
|
result: Dict[int, StateMap[str]] = {}
|
||||||
|
|
||||||
|
for group in groups:
|
||||||
|
group_result: MutableStateMap[str] = {}
|
||||||
|
result[group] = group_result
|
||||||
|
|
||||||
|
for state_type, state_keys in state_filter.types.items():
|
||||||
|
if state_keys is None:
|
||||||
|
group_result[
|
||||||
|
(state_type, "wild wombat")
|
||||||
|
] = f"{group} {state_type} wild wombat"
|
||||||
|
group_result[
|
||||||
|
(state_type, "wild spqr")
|
||||||
|
] = f"{group} {state_type} wild spqr"
|
||||||
|
else:
|
||||||
|
for state_key in state_keys:
|
||||||
|
group_result[
|
||||||
|
(state_type, state_key)
|
||||||
|
] = f"{group} {state_type} {state_key}"
|
||||||
|
|
||||||
|
if state_filter.include_others:
|
||||||
|
group_result[("something.else", "wild")] = "card"
|
||||||
|
|
||||||
|
d.callback(result)
|
||||||
|
|
||||||
|
def test_duplicate_requests_deduplicated(self) -> None:
|
||||||
|
req1 = ensureDeferred(
|
||||||
|
self.state_datastore._get_state_for_group_using_inflight_cache(
|
||||||
|
42, StateFilter.all()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.pump(by=0.1)
|
||||||
|
|
||||||
|
# This should have gone to the database
|
||||||
|
self.assertEqual(len(self.gsgfg_calls), 1)
|
||||||
|
self.assertFalse(req1.called)
|
||||||
|
|
||||||
|
req2 = ensureDeferred(
|
||||||
|
self.state_datastore._get_state_for_group_using_inflight_cache(
|
||||||
|
42, StateFilter.all()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.pump(by=0.1)
|
||||||
|
|
||||||
|
# No more calls should have gone to the database
|
||||||
|
self.assertEqual(len(self.gsgfg_calls), 1)
|
||||||
|
self.assertFalse(req1.called)
|
||||||
|
self.assertFalse(req2.called)
|
||||||
|
|
||||||
|
groups, sf, d = self.gsgfg_calls[0]
|
||||||
|
self.assertEqual(groups, (42,))
|
||||||
|
self.assertEqual(sf, StateFilter.all())
|
||||||
|
|
||||||
|
# Now we can complete the request
|
||||||
|
self._complete_request_fake(groups, sf, d)
|
||||||
|
|
||||||
|
self.assertEqual(self.get_success(req1), {("something.else", "wild"): "card"})
|
||||||
|
self.assertEqual(self.get_success(req2), {("something.else", "wild"): "card"})
|
Loading…
Reference in a new issue