Add method to request 1 state group, tracking the inflight request
This commit is contained in:
parent
363565e6ac
commit
1fceefd65d
|
@ -13,11 +13,24 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import ensureDeferred
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
|
@ -29,6 +42,7 @@ from synapse.storage.state import StateFilter
|
|||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.sequence import build_sequence_generator
|
||||
from synapse.types import MutableStateMap, StateKey, StateMap
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.caches.dictionary_cache import DictionaryCache
|
||||
|
||||
|
@ -37,7 +51,6 @@ if TYPE_CHECKING:
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MAX_STATE_DELTA_HOPS = 100
|
||||
|
||||
|
||||
|
@ -106,6 +119,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
500000,
|
||||
)
|
||||
|
||||
# Current ongoing get_state_for_groups in-flight requests
|
||||
# {group ID -> {StateFilter -> ObservableDeferred}}
|
||||
self._state_group_inflight_requests: Dict[
|
||||
int, Dict[StateFilter, ObservableDeferred[StateMap[str]]]
|
||||
] = {}
|
||||
|
||||
def get_max_state_group_txn(txn: Cursor) -> int:
|
||||
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
|
||||
return txn.fetchone()[0] # type: ignore
|
||||
|
@ -228,6 +247,59 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
|
||||
return state_filter.filter_state(state_dict_ids), not missing_types
|
||||
|
||||
async def _get_state_for_group_fire_request(
|
||||
self, group: int, state_filter: StateFilter
|
||||
) -> StateMap[str]:
|
||||
"""
|
||||
Fires off a request to get the state at a state group,
|
||||
potentially filtering by type and/or state key.
|
||||
|
||||
This request will be tracked in the in-flight request cache and automatically
|
||||
removed when it is finished.
|
||||
|
||||
Args:
|
||||
group: ID of the state group for which we want to get state
|
||||
state_filter: the state filter used to fetch state from the database
|
||||
"""
|
||||
cache_sequence_nm = self._state_group_cache.sequence
|
||||
cache_sequence_m = self._state_group_members_cache.sequence
|
||||
|
||||
# Help the cache hit ratio by expanding the filter a bit
|
||||
db_state_filter = state_filter.return_expanded()
|
||||
|
||||
async def _the_request() -> StateMap[str]:
|
||||
group_to_state_dict = await self._get_state_groups_from_groups(
|
||||
(group,), state_filter=db_state_filter
|
||||
)
|
||||
|
||||
# Now let's update the caches
|
||||
self._insert_into_cache(
|
||||
group_to_state_dict,
|
||||
db_state_filter,
|
||||
cache_seq_num_members=cache_sequence_m,
|
||||
cache_seq_num_non_members=cache_sequence_nm,
|
||||
)
|
||||
|
||||
# Remove ourselves from the in-flight cache
|
||||
group_request_dict = self._state_group_inflight_requests[group]
|
||||
del group_request_dict[db_state_filter]
|
||||
if not group_request_dict:
|
||||
# If there are no more requests in-flight for this group,
|
||||
# clean up the cache by removing the empty dictionary
|
||||
del self._state_group_inflight_requests[group]
|
||||
|
||||
return group_to_state_dict[group]
|
||||
|
||||
# OSTD is this right wrt. log context rules?
|
||||
request_deferred = ensureDeferred(_the_request())
|
||||
observable_deferred = ObservableDeferred(request_deferred)
|
||||
|
||||
# Insert the ObservableDeferred into the cache
|
||||
group_request_dict = self._state_group_inflight_requests.setdefault(group, {})
|
||||
group_request_dict[db_state_filter] = observable_deferred
|
||||
|
||||
return await request_deferred
|
||||
|
||||
async def _get_state_for_groups(
|
||||
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
||||
) -> Dict[int, MutableStateMap[str]]:
|
||||
|
|
Loading…
Reference in a new issue