Compare commits

...

16 commits

Author SHA1 Message Date
Olivier Wilkinson (reivilibre) e38f7953ef Merge branch 'develop' into rei/gsfg_1 2021-11-04 13:38:46 +00:00
Olivier Wilkinson (reivilibre) db840d2ad5 Simplify gatherResults and fix log contexts 2021-11-04 13:35:13 +00:00
Olivier Wilkinson (reivilibre) 5352f2109c Directly await fetched state for simplicity 2021-11-04 13:31:01 +00:00
Olivier Wilkinson (reivilibre) 8ea530a7e5 Check != against StateFilter.none() for clarity 2021-11-04 13:25:13 +00:00
Olivier Wilkinson (reivilibre) f78a082989 Fix up log contexts in _get_state_for_group_fire_request 2021-11-04 13:23:45 +00:00
Olivier Wilkinson (reivilibre) af85ac449d Merge remote-tracking branch 'origin/develop' into rei/gsfg_1
to introduce `approx_difference`.
2021-10-12 10:47:13 +01:00
Olivier Wilkinson (reivilibre) 58ef32e272 Revert "Add a stub implementation of StateFilter.approx_difference"
This reverts commit 363565e6ac
because a proper implementation is now in develop.
2021-10-12 10:46:50 +01:00
Olivier Wilkinson (reivilibre) cc76d9f100 Use yieldable_gather_results helper because it's more elegant 2021-09-27 12:46:54 +01:00
Olivier Wilkinson (reivilibre) a30042b16a Convert to review comments 2021-09-21 17:14:19 +01:00
Olivier Wilkinson (reivilibre) 857b2d2039 Newsfile
Signed-off-by: Olivier Wilkinson (reivilibre) <oliverw@matrix.org>
2021-09-21 17:14:19 +01:00
Olivier Wilkinson (reivilibre) 831a7a4592 Add some basic tests about requests getting deduplicated
More to follow when #10825 is ready.
2021-09-21 17:14:19 +01:00
Olivier Wilkinson (reivilibre) 7dcdab407c Use the in-flight caches for _get_state_for_groups 2021-09-21 17:14:18 +01:00
Olivier Wilkinson (reivilibre) 247f558c1c Add method to get 1 state group, both using and updating in-flight cache 2021-09-21 17:14:18 +01:00
Olivier Wilkinson (reivilibre) d998903d46 Add method to gather in-flight requests and calculate left-over filter 2021-09-21 17:14:18 +01:00
Olivier Wilkinson (reivilibre) 1fceefd65d Add method to request 1 state group, tracking the inflight request 2021-09-21 17:14:18 +01:00
Olivier Wilkinson (reivilibre) 363565e6ac Add a stub implementation of StateFilter.approx_difference 2021-09-21 17:10:02 +01:00
4 changed files with 277 additions and 25 deletions

changelog.d/10870.misc (new file)

@@ -0,0 +1 @@
Deduplicate in-flight requests in `_get_state_for_groups`.

mypy.ini

@@ -82,6 +82,7 @@ files =
tests/rest/client/test_relations.py,
tests/rest/media/v1/test_filepath.py,
tests/rest/media/v1/test_oembed.py,
tests/storage/databases/test_state_store.py,
tests/storage/test_state.py,
tests/storage/test_user_directory.py,
tests/util/test_itertools.py,

synapse/storage/databases/state/store.py

@@ -13,11 +13,23 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
Optional,
Sequence,
Set,
Tuple,
)
import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -29,6 +41,8 @@ from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -37,7 +51,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
@@ -106,6 +119,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
500000,
)
# Current ongoing get_state_for_groups in-flight requests
# {group ID -> {StateFilter -> ObservableDeferred}}
self._state_group_inflight_requests: Dict[
int, Dict[StateFilter, ObservableDeferred[StateMap[str]]]
] = {}
def get_max_state_group_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
return txn.fetchone()[0] # type: ignore
@@ -157,7 +176,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
async def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
self, groups: Sequence[int], state_filter: StateFilter
) -> Dict[int, StateMap[str]]:
"""Returns the state groups for a given set of groups from the
database, filtering on types of state events.
@@ -228,6 +247,144 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types
def _get_state_for_group_gather_inflight_requests(
self, group: int, state_filter: StateFilter
) -> Tuple[Sequence[ObservableDeferred[StateMap[str]]], StateFilter]:
"""
Attempts to gather in-flight requests and re-use them to retrieve state
for the given state group, filtered with the given state filter.
Returns:
tuple (
sequence of ObservableDeferreds to observe,
StateFilter representing what is left over (what else needs
to be requested to fulfil the request)
)
"""
inflight_requests = self._state_group_inflight_requests.get(group)
if inflight_requests is None:
# no requests for this group, need to retrieve it all ourselves
return (), state_filter
state_filter_left_over = state_filter
reusable_requests = []
for (
request_state_filter,
request_deferred,
) in inflight_requests.items():
new_state_filter_left_over = state_filter_left_over.approx_difference(
request_state_filter
)
if new_state_filter_left_over != state_filter_left_over:
# reusing this request narrows our StateFilter down a bit.
reusable_requests.append(request_deferred)
state_filter_left_over = new_state_filter_left_over
if state_filter_left_over == StateFilter.none():
# we have managed to collect enough of the in-flight requests
# to cover our StateFilter and give us the state we need.
break
return reusable_requests, state_filter_left_over
async def _get_state_for_group_fire_request(
self, group: int, state_filter: StateFilter
) -> StateMap[str]:
"""
Fires off a request to get the state at a state group,
potentially filtering by type and/or state key.
This request will be tracked in the in-flight request cache and automatically
removed when it is finished.
Args:
group: ID of the state group for which we want to get state
state_filter: the state filter used to fetch state from the database
"""
cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence
# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()
async def _the_request() -> StateMap[str]:
group_to_state_dict = await self._get_state_groups_from_groups(
(group,), state_filter=db_state_filter
)
# Now let's update the caches
self._insert_into_cache(
group_to_state_dict,
db_state_filter,
cache_seq_num_members=cache_sequence_m,
cache_seq_num_non_members=cache_sequence_nm,
)
# Remove ourselves from the in-flight cache
group_request_dict = self._state_group_inflight_requests[group]
del group_request_dict[db_state_filter]
if not group_request_dict:
# If there are no more requests in-flight for this group,
# clean up the cache by removing the empty dictionary
del self._state_group_inflight_requests[group]
return group_to_state_dict[group]
# We don't immediately await the result, so must use run_in_background
# But we DO await the result before the current log context (request)
# finishes, so don't need to run it as a background process.
request_deferred = run_in_background(_the_request)
observable_deferred = ObservableDeferred(request_deferred)
# Insert the ObservableDeferred into the cache
group_request_dict = self._state_group_inflight_requests.setdefault(group, {})
group_request_dict[db_state_filter] = observable_deferred
return await make_deferred_yieldable(observable_deferred.observe())
async def _get_state_for_group_using_inflight_cache(
self, group: int, state_filter: StateFilter
) -> MutableStateMap[str]:
"""
Gets the state at a state group, potentially filtering by type and/or
state key.
Args:
group: ID of the state group for which we want to get state
state_filter: the state filter used to fetch state from the database
Returns:
state map
"""
# first, figure out whether we can re-use any in-flight requests
# (and if so, what would be left over)
(
reusable_requests,
state_filter_left_over,
) = self._get_state_for_group_gather_inflight_requests(group, state_filter)
if state_filter_left_over != StateFilter.none():
# Fetch remaining state
remaining = await self._get_state_for_group_fire_request(
group, state_filter_left_over
)
assembled_state: MutableStateMap[str] = dict(remaining)
else:
assembled_state = {}
gathered = await make_deferred_yieldable(
defer.gatherResults(
(r.observe() for r in reusable_requests), consumeErrors=True
)
).addErrback(unwrapFirstError)
# assemble our result.
for result_piece in gathered:
assembled_state.update(result_piece)
# Filter out any state that may be more than what we asked for.
return state_filter.filter_state(assembled_state)
async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]:
@@ -269,30 +426,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
if not incomplete_groups:
return state
cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence
# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()
group_to_state_dict = await self._get_state_groups_from_groups(
list(incomplete_groups), state_filter=db_state_filter
results_from_requests = await yieldable_gather_results(
self._get_state_for_group_using_inflight_cache,
incomplete_groups,
state_filter,
)
# Now lets update the caches
self._insert_into_cache(
group_to_state_dict,
db_state_filter,
cache_seq_num_members=cache_sequence_m,
cache_seq_num_non_members=cache_sequence_nm,
)
# And finally update the result dict, by filtering out any extra
# stuff we pulled out of the database.
for group, group_state_dict in group_to_state_dict.items():
# We just replace any existing entries, as we will have loaded
# everything we need from the database anyway.
state[group] = state_filter.filter_state(group_state_dict)
for group, group_result in zip(incomplete_groups, results_from_requests):
state[group] = group_result
return state
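
Taken together, the new _get_state_for_group_* methods above implement a small in-flight request cache: look up overlapping requests for the same group, work out what is still missing, fire a single request for the remainder, and drop the cache entry once that request resolves. The standalone sketch below illustrates the same pattern under simplifying assumptions: asyncio instead of Twisted's ObservableDeferred, a plain frozenset of state keys instead of StateFilter (so the difference is exact rather than the over-approximation that `approx_difference` may return), and hypothetical names (`InflightDeduplicator`, `fetch_from_db`) that do not exist in Synapse.

# Standalone sketch of the in-flight deduplication pattern (hypothetical names;
# not Synapse code). One dict tracks the requests currently hitting the
# database, keyed by group and by the set of state keys each request covers.
import asyncio
from typing import Awaitable, Callable, Dict, FrozenSet

StateMap = Dict[str, str]
Fetcher = Callable[[int, FrozenSet[str]], Awaitable[StateMap]]


class InflightDeduplicator:
    def __init__(self, fetch_from_db: Fetcher) -> None:
        # {group ID -> {requested state keys -> in-flight task for that request}}
        self._inflight: Dict[int, Dict[FrozenSet[str], "asyncio.Task[StateMap]"]] = {}
        self._fetch_from_db = fetch_from_db

    async def get_state(self, group: int, keys: FrozenSet[str]) -> StateMap:
        # 1. Gather in-flight requests that cover part of what we need and work
        #    out what is left over (the real code uses StateFilter.approx_difference
        #    here, which may leave the filter broader than strictly necessary).
        reusable = []
        left_over = keys
        for requested_keys, task in self._inflight.get(group, {}).items():
            narrowed = left_over - requested_keys
            if narrowed != left_over:
                reusable.append(task)
                left_over = narrowed
                if not left_over:
                    break

        # 2. Fire a single new request for the remainder and track it so that
        #    concurrent callers can re-use it too.
        if left_over:
            task = asyncio.ensure_future(self._fetch_and_untrack(group, left_over))
            self._inflight.setdefault(group, {})[left_over] = task
            reusable.append(task)

        # 3. Assemble the pieces and keep only what was actually asked for.
        assembled: StateMap = {}
        for piece in await asyncio.gather(*reusable):
            assembled.update(piece)
        return {k: v for k, v in assembled.items() if k in keys}

    async def _fetch_and_untrack(self, group: int, keys: FrozenSet[str]) -> StateMap:
        try:
            return await self._fetch_from_db(group, keys)
        finally:
            # Remove ourselves from the in-flight cache once finished, pruning
            # the per-group dict when it becomes empty.
            group_requests = self._inflight[group]
            del group_requests[keys]
            if not group_requests:
                del self._inflight[group]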

tests/storage/databases/test_state_store.py (new file)

@@ -0,0 +1,109 @@
from typing import Dict, List, Sequence, Tuple
from unittest.mock import patch
from twisted.internet.defer import Deferred, ensureDeferred
from synapse.storage.state import StateFilter
from synapse.types import MutableStateMap, StateMap
from tests.unittest import HomeserverTestCase
class StateGroupInflightCachingTestCase(HomeserverTestCase):
def setUp(self) -> None:
super(StateGroupInflightCachingTestCase, self).setUp()
# Patch out the `_get_state_groups_from_groups`.
# This is useful because it lets us pretend we have a slow database.
gsgfg_patch = patch(
"synapse.storage.databases.state.store.StateGroupDataStore._get_state_groups_from_groups",
self._fake_get_state_groups_from_groups,
)
gsgfg_patch.start()
self.addCleanup(gsgfg_patch.stop)
self.gsgfg_calls: List[
Tuple[Tuple[int, ...], StateFilter, Deferred[Dict[int, StateMap[str]]]]
] = []
def prepare(self, reactor, clock, homeserver) -> None:
super(StateGroupInflightCachingTestCase, self).prepare(
reactor, clock, homeserver
)
self.state_storage = homeserver.get_storage().state
self.state_datastore = homeserver.get_datastores().state
def _fake_get_state_groups_from_groups(
self, groups: Sequence[int], state_filter: StateFilter
) -> "Deferred[Dict[int, StateMap[str]]]":
print("hi", groups, state_filter)
d: Deferred[Dict[int, StateMap[str]]] = Deferred()
self.gsgfg_calls.append((tuple(groups), state_filter, d))
return d
def _complete_request_fake(
self,
groups: Tuple[int, ...],
state_filter: StateFilter,
d: "Deferred[Dict[int, StateMap[str]]]",
) -> None:
"""
Assemble a fake database response and complete the database request.
"""
result: Dict[int, StateMap[str]] = {}
for group in groups:
group_result: MutableStateMap[str] = {}
result[group] = group_result
for state_type, state_keys in state_filter.types.items():
if state_keys is None:
group_result[
(state_type, "wild wombat")
] = f"{group} {state_type} wild wombat"
group_result[
(state_type, "wild spqr")
] = f"{group} {state_type} wild spqr"
else:
for state_key in state_keys:
group_result[
(state_type, state_key)
] = f"{group} {state_type} {state_key}"
if state_filter.include_others:
group_result[("something.else", "wild")] = "card"
d.callback(result)
def test_duplicate_requests_deduplicated(self) -> None:
req1 = ensureDeferred(
self.state_datastore._get_state_for_group_using_inflight_cache(
42, StateFilter.all()
)
)
self.pump(by=0.1)
# This should have gone to the database
self.assertEqual(len(self.gsgfg_calls), 1)
self.assertFalse(req1.called)
req2 = ensureDeferred(
self.state_datastore._get_state_for_group_using_inflight_cache(
42, StateFilter.all()
)
)
self.pump(by=0.1)
# No more calls should have gone to the database
self.assertEqual(len(self.gsgfg_calls), 1)
self.assertFalse(req1.called)
self.assertFalse(req2.called)
groups, sf, d = self.gsgfg_calls[0]
self.assertEqual(groups, (42,))
self.assertEqual(sf, StateFilter.all())
# Now we can complete the request
self._complete_request_fake(groups, sf, d)
self.assertEqual(self.get_success(req1), {("something.else", "wild"): "card"})
self.assertEqual(self.get_success(req2), {("something.else", "wild"): "card"})
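
For comparison with the test above, the same two-callers-one-database-call behaviour can be exercised against the hypothetical `InflightDeduplicator` sketch given after the store.py diff (again asyncio and a fake fetch function, not Synapse's HomeserverTestCase harness).

# Two concurrent callers for the same group share a single backend fetch
# (mirrors test_duplicate_requests_deduplicated, but against the sketch above).
import asyncio

async def main() -> None:
    calls = []

    async def fake_db(group: int, keys: frozenset) -> dict:
        calls.append((group, keys))
        await asyncio.sleep(0.1)  # pretend the database is slow
        return {key: f"{group} {key}" for key in keys}

    dedup = InflightDeduplicator(fake_db)
    keys = frozenset({"m.room.name", "m.room.topic"})
    result_1, result_2 = await asyncio.gather(
        dedup.get_state(42, keys), dedup.get_state(42, keys)
    )
    assert result_1 == result_2
    assert len(calls) == 1  # only one call reached the fake database

asyncio.run(main())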