mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-13 21:43:22 +01:00
Reduce amount of state pulled out when querying federation hierachy (#16785)
There are two changes here: 1. Only pull out the required state when handling the request. 2. Change the get filtered state return type to check that we're only querying state that was requested --------- Co-authored-by: reivilibre <oliverw@matrix.org>
This commit is contained in:
parent
4c67f0391b
commit
578c5c736e
4 changed files with 82 additions and 3 deletions
1
changelog.d/16785.misc
Normal file
1
changelog.d/16785.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Reduce amount of state pulled out when querying federation hierachy.
|
|
@ -44,6 +44,7 @@ from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.config.ratelimiting import RatelimitSettings
|
from synapse.config.ratelimiting import RatelimitSettings
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.types import JsonDict, Requester, StrCollection
|
from synapse.types import JsonDict, Requester, StrCollection
|
||||||
|
from synapse.types.state import StateFilter
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -546,7 +547,16 @@ class RoomSummaryHandler:
|
||||||
Returns:
|
Returns:
|
||||||
True if the room is accessible to the requesting user or server.
|
True if the room is accessible to the requesting user or server.
|
||||||
"""
|
"""
|
||||||
state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
|
event_types = [
|
||||||
|
(EventTypes.JoinRules, ""),
|
||||||
|
(EventTypes.RoomHistoryVisibility, ""),
|
||||||
|
]
|
||||||
|
if requester:
|
||||||
|
event_types.append((EventTypes.Member, requester))
|
||||||
|
|
||||||
|
state_ids = await self._storage_controllers.state.get_current_state_ids(
|
||||||
|
room_id, state_filter=StateFilter.from_types(event_types)
|
||||||
|
)
|
||||||
|
|
||||||
# If there's no state for the room, it isn't known.
|
# If there's no state for the room, it isn't known.
|
||||||
if not state_ids:
|
if not state_ids:
|
||||||
|
|
|
@ -30,7 +30,10 @@ from typing import (
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
cast,
|
cast,
|
||||||
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -52,7 +55,7 @@ from synapse.storage.database import (
|
||||||
)
|
)
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||||
from synapse.types import JsonDict, JsonMapping, StateMap
|
from synapse.types import JsonDict, JsonMapping, StateKey, StateMap
|
||||||
from synapse.types.state import StateFilter
|
from synapse.types.state import StateFilter
|
||||||
from synapse.util.caches import intern_string
|
from synapse.util.caches import intern_string
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
@ -64,6 +67,8 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
MAX_STATE_DELTA_HOPS = 100
|
MAX_STATE_DELTA_HOPS = 100
|
||||||
|
|
||||||
|
@ -349,7 +354,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
def _get_filtered_current_state_ids_txn(
|
def _get_filtered_current_state_ids_txn(
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
) -> StateMap[str]:
|
) -> StateMap[str]:
|
||||||
results = {}
|
results = StateMapWrapper(state_filter=state_filter or StateFilter.all())
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
SELECT type, state_key, event_id FROM current_state_events
|
SELECT type, state_key, event_id FROM current_state_events
|
||||||
WHERE room_id = ?
|
WHERE room_id = ?
|
||||||
|
@ -726,3 +732,41 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
|
||||||
hs: "HomeServer",
|
hs: "HomeServer",
|
||||||
):
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(auto_attribs=True, slots=True)
|
||||||
|
class StateMapWrapper(Dict[StateKey, str]):
|
||||||
|
"""A wrapper around a StateMap[str] to ensure that we only query for items
|
||||||
|
that were not filtered out.
|
||||||
|
|
||||||
|
This is to help prevent bugs where we filter out state but other bits of the
|
||||||
|
code expect the state to be there.
|
||||||
|
"""
|
||||||
|
|
||||||
|
state_filter: StateFilter
|
||||||
|
|
||||||
|
def __getitem__(self, key: StateKey) -> str:
|
||||||
|
if key not in self.state_filter:
|
||||||
|
raise Exception("State map was filtered and doesn't include: %s", key)
|
||||||
|
return super().__getitem__(key)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get(self, key: Tuple[str, str]) -> Optional[str]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get(self, key: Tuple[str, str], default: Union[str, _T]) -> Union[str, _T]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self, key: StateKey, default: Union[str, _T, None] = None
|
||||||
|
) -> Union[str, _T, None]:
|
||||||
|
if key not in self.state_filter:
|
||||||
|
raise Exception("State map was filtered and doesn't include: %s", key)
|
||||||
|
return super().get(key, default)
|
||||||
|
|
||||||
|
def __contains__(self, key: Any) -> bool:
|
||||||
|
if key not in self.state_filter:
|
||||||
|
raise Exception("State map was filtered and doesn't include: %s", key)
|
||||||
|
|
||||||
|
return super().__contains__(key)
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
|
@ -584,6 +585,29 @@ class StateFilter:
|
||||||
# local users only
|
# local users only
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def __contains__(self, key: Any) -> bool:
|
||||||
|
if not isinstance(key, tuple) or len(key) != 2:
|
||||||
|
raise TypeError(
|
||||||
|
f"'in StateFilter' requires (str, str) as left operand, not {type(key).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
typ, state_key = key
|
||||||
|
|
||||||
|
if not isinstance(typ, str) or not isinstance(state_key, str):
|
||||||
|
raise TypeError(
|
||||||
|
f"'in StateFilter' requires (str, str) as left operand, not ({type(typ).__name__}, {type(state_key).__name__})"
|
||||||
|
)
|
||||||
|
|
||||||
|
if typ in self.types:
|
||||||
|
state_keys = self.types[typ]
|
||||||
|
if state_keys is None or state_key in state_keys:
|
||||||
|
return True
|
||||||
|
|
||||||
|
elif self.include_others:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
_ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True)
|
_ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True)
|
||||||
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
|
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
|
||||||
|
|
Loading…
Reference in a new issue