Add method to request 1 state group, tracking the inflight request

This commit is contained in:
Olivier Wilkinson (reivilibre) 2021-09-21 17:12:03 +01:00
parent 363565e6ac
commit 1fceefd65d

View file

@ -13,11 +13,24 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
Optional,
Sequence,
Set,
Tuple,
)
import attr
from twisted.internet import defer
from twisted.internet.defer import ensureDeferred
from synapse.api.constants import EventTypes
from synapse.logging.context import make_deferred_yieldable
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@ -29,6 +42,7 @@ from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@ -37,7 +51,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
@ -106,6 +119,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
500000,
)
# Current ongoing get_state_for_groups in-flight requests
# {group ID -> {StateFilter -> ObservableDeferred}}
self._state_group_inflight_requests: Dict[
int, Dict[StateFilter, ObservableDeferred[StateMap[str]]]
] = {}
def get_max_state_group_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
return txn.fetchone()[0] # type: ignore
@ -228,6 +247,59 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types
async def _get_state_for_group_fire_request(
self, group: int, state_filter: StateFilter
) -> StateMap[str]:
"""
Fires off a request to get the state at a state group,
potentially filtering by type and/or state key.
This request will be tracked in the in-flight request cache and automatically
removed when it is finished.
Args:
group: ID of the state group for which we want to get state
state_filter: the state filter used to fetch state from the database
"""
cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence
# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()
async def _the_request() -> StateMap[str]:
group_to_state_dict = await self._get_state_groups_from_groups(
(group,), state_filter=db_state_filter
)
# Now let's update the caches
self._insert_into_cache(
group_to_state_dict,
db_state_filter,
cache_seq_num_members=cache_sequence_m,
cache_seq_num_non_members=cache_sequence_nm,
)
# Remove ourselves from the in-flight cache
group_request_dict = self._state_group_inflight_requests[group]
del group_request_dict[db_state_filter]
if not group_request_dict:
# If there are no more requests in-flight for this group,
# clean up the cache by removing the empty dictionary
del self._state_group_inflight_requests[group]
return group_to_state_dict[group]
# OSTD is this right wrt. log context rules?
request_deferred = ensureDeferred(_the_request())
observable_deferred = ObservableDeferred(request_deferred)
# Insert the ObservableDeferred into the cache
group_request_dict = self._state_group_inflight_requests.setdefault(group, {})
group_request_dict[db_state_filter] = observable_deferred
return await request_deferred
async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]: