0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 19:13:51 +01:00

Define StateMap as immutable and add a MutableStateMap type. (#8183)

This commit is contained in:
Patrick Cloke 2020-08-28 07:28:53 -04:00 committed by GitHub
parent 2c2e649be2
commit d5e73cb6aa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 52 additions and 32 deletions

1
changelog.d/8183.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `synapse.state`.

View file

@ -72,7 +72,13 @@ from synapse.replication.http.federation import (
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id from synapse.types import (
JsonDict,
MutableStateMap,
StateMap,
UserID,
get_domain_from_id,
)
from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
@ -96,7 +102,7 @@ class _NewEventInfo:
event = attr.ib(type=EventBase) event = attr.ib(type=EventBase)
state = attr.ib(type=Optional[Sequence[EventBase]], default=None) state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None) auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None)
class FederationHandler(BaseHandler): class FederationHandler(BaseHandler):
@ -2053,7 +2059,7 @@ class FederationHandler(BaseHandler):
origin: str, origin: str,
event: EventBase, event: EventBase,
state: Optional[Iterable[EventBase]], state: Optional[Iterable[EventBase]],
auth_events: Optional[StateMap[EventBase]], auth_events: Optional[MutableStateMap[EventBase]],
backfilled: bool, backfilled: bool,
) -> EventContext: ) -> EventContext:
context = await self.state_handler.compute_event_context(event, old_state=state) context = await self.state_handler.compute_event_context(event, old_state=state)
@ -2137,7 +2143,9 @@ class FederationHandler(BaseHandler):
current_states = await self.state_handler.resolve_events( current_states = await self.state_handler.resolve_events(
room_version, state_sets, event room_version, state_sets, event
) )
current_state_ids = {k: e.event_id for k, e in current_states.items()} current_state_ids = {
k: e.event_id for k, e in current_states.items()
} # type: StateMap[str]
else: else:
current_state_ids = await self.state_handler.get_current_state_ids( current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids event.room_id, latest_event_ids=extrem_ids
@ -2223,7 +2231,7 @@ class FederationHandler(BaseHandler):
origin: str, origin: str,
event: EventBase, event: EventBase,
context: EventContext, context: EventContext,
auth_events: StateMap[EventBase], auth_events: MutableStateMap[EventBase],
) -> EventContext: ) -> EventContext:
""" """
@ -2274,7 +2282,7 @@ class FederationHandler(BaseHandler):
origin: str, origin: str,
event: EventBase, event: EventBase,
context: EventContext, context: EventContext,
auth_events: StateMap[EventBase], auth_events: MutableStateMap[EventBase],
) -> EventContext: ) -> EventContext:
"""Helper for do_auth. See there for docs. """Helper for do_auth. See there for docs.

View file

@ -41,6 +41,7 @@ from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import ( from synapse.types import (
JsonDict, JsonDict,
MutableStateMap,
Requester, Requester,
RoomAlias, RoomAlias,
RoomID, RoomID,
@ -814,7 +815,7 @@ class RoomCreationHandler(BaseHandler):
room_id: str, room_id: str,
preset_config: str, preset_config: str,
invite_list: List[str], invite_list: List[str],
initial_state: StateMap, initial_state: MutableStateMap,
creation_content: JsonDict, creation_content: JsonDict,
room_alias: Optional[RoomAlias] = None, room_alias: Optional[RoomAlias] = None,
power_level_content_override: Optional[JsonDict] = None, power_level_content_override: Optional[JsonDict] = None,

View file

@ -31,6 +31,7 @@ from synapse.storage.state import StateFilter
from synapse.types import ( from synapse.types import (
Collection, Collection,
JsonDict, JsonDict,
MutableStateMap,
RoomStreamToken, RoomStreamToken,
StateMap, StateMap,
StreamToken, StreamToken,
@ -588,7 +589,7 @@ class SyncHandler(object):
room_id: str, room_id: str,
sync_config: SyncConfig, sync_config: SyncConfig,
batch: TimelineBatch, batch: TimelineBatch,
state: StateMap[EventBase], state: MutableStateMap[EventBase],
now_token: StreamToken, now_token: StreamToken,
) -> Optional[JsonDict]: ) -> Optional[JsonDict]:
""" Works out a room summary block for this room, summarising the number """ Works out a room summary block for this room, summarising the number
@ -736,7 +737,7 @@ class SyncHandler(object):
since_token: Optional[StreamToken], since_token: Optional[StreamToken],
now_token: StreamToken, now_token: StreamToken,
full_state: bool, full_state: bool,
) -> StateMap[EventBase]: ) -> MutableStateMap[EventBase]:
""" Works out the difference in state between the start of the timeline """ Works out the difference in state between the start of the timeline
and the previous sync. and the previous sync.

View file

@ -25,6 +25,7 @@ from typing import (
Sequence, Sequence,
Set, Set,
Union, Union,
cast,
overload, overload,
) )
@ -41,7 +42,7 @@ from synapse.logging.utils import log_function
from synapse.state import v1, v2 from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.types import Collection, StateMap from synapse.types import Collection, MutableStateMap, StateMap
from synapse.util import Clock from synapse.util import Clock
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -205,7 +206,7 @@ class StateHandler(object):
logger.debug("calling resolve_state_groups from get_current_state_ids") logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return dict(ret.state) return ret.state
async def get_current_users_in_room( async def get_current_users_in_room(
self, room_id: str, latest_event_ids: Optional[List[str]] = None self, room_id: str, latest_event_ids: Optional[List[str]] = None
@ -302,7 +303,7 @@ class StateHandler(object):
# if we're given the state before the event, then we use that # if we're given the state before the event, then we use that
state_ids_before_event = { state_ids_before_event = {
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} } # type: StateMap[str]
state_group_before_event = None state_group_before_event = None
state_group_before_event_prev_group = None state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None deltas_to_state_group_before_event = None
@ -315,7 +316,7 @@ class StateHandler(object):
event.room_id, event.prev_event_ids() event.room_id, event.prev_event_ids()
) )
state_ids_before_event = dict(entry.state) state_ids_before_event = entry.state
state_group_before_event = entry.state_group state_group_before_event = entry.state_group
state_group_before_event_prev_group = entry.prev_group state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids deltas_to_state_group_before_event = entry.delta_ids
@ -540,7 +541,7 @@ class StateResolutionHandler(object):
# #
# XXX: is this actually worthwhile, or should we just let # XXX: is this actually worthwhile, or should we just let
# resolve_events_with_store do it? # resolve_events_with_store do it?
new_state = {} new_state = {} # type: MutableStateMap[str]
conflicted_state = False conflicted_state = False
for st in state_groups_ids.values(): for st in state_groups_ids.values():
for key, e_id in st.items(): for key, e_id in st.items():
@ -554,13 +555,20 @@ class StateResolutionHandler(object):
if conflicted_state: if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id) logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = await resolve_events_with_store( # resolve_events_with_store returns a StateMap, but we can
# treat it as a MutableStateMap as it is above. It isn't
# actually mutated anymore (and is frozen in
# _make_state_cache_entry below).
new_state = cast(
MutableStateMap,
await resolve_events_with_store(
self.clock, self.clock,
room_id, room_id,
room_version, room_version,
list(state_groups_ids.values()), list(state_groups_ids.values()),
event_map=event_map, event_map=event_map,
state_res_store=state_res_store, state_res_store=state_res_store,
),
) )
# if the new state matches any of the input state groups, we can # if the new state matches any of the input state groups, we can

View file

@ -32,7 +32,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import StateMap from synapse.types import MutableStateMap, StateMap
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -131,7 +131,7 @@ async def resolve_events_with_store(
def _seperate( def _seperate(
state_sets: Iterable[StateMap[str]], state_sets: Iterable[StateMap[str]],
) -> Tuple[StateMap[str], StateMap[Set[str]]]: ) -> Tuple[MutableStateMap[str], MutableStateMap[Set[str]]]:
"""Takes the state_sets and figures out which keys are conflicted and """Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated which aren't. i.e., which have multiple different event_ids associated
with them in different state sets. with them in different state sets.
@ -152,7 +152,7 @@ def _seperate(
""" """
state_set_iterator = iter(state_sets) state_set_iterator = iter(state_sets)
unconflicted_state = dict(next(state_set_iterator)) unconflicted_state = dict(next(state_set_iterator))
conflicted_state = {} # type: StateMap[Set[str]] conflicted_state = {} # type: MutableStateMap[Set[str]]
for state_set in state_set_iterator: for state_set in state_set_iterator:
for key, value in state_set.items(): for key, value in state_set.items():
@ -208,7 +208,7 @@ def _create_auth_events_from_maps(
def _resolve_with_state( def _resolve_with_state(
unconflicted_state_ids: StateMap[str], unconflicted_state_ids: MutableStateMap[str],
conflicted_state_ids: StateMap[Set[str]], conflicted_state_ids: StateMap[Set[str]],
auth_event_ids: StateMap[str], auth_event_ids: StateMap[str],
state_map: Dict[str, EventBase], state_map: Dict[str, EventBase],
@ -241,7 +241,7 @@ def _resolve_with_state(
def _resolve_state_events( def _resolve_state_events(
conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase] conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
) -> StateMap[EventBase]: ) -> StateMap[EventBase]:
""" This is where we actually decide which of the conflicted state to """ This is where we actually decide which of the conflicted state to
use. use.

View file

@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import StateMap from synapse.types import MutableStateMap, StateMap
from synapse.util import Clock from synapse.util import Clock
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -414,7 +414,7 @@ async def _iterative_auth_checks(
base_state: StateMap[str], base_state: StateMap[str],
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore", state_res_store: "synapse.state.StateResolutionStore",
) -> StateMap[str]: ) -> MutableStateMap[str]:
"""Sequentially apply auth checks to each event in given list, updating the """Sequentially apply auth checks to each event in given list, updating the
state as it goes along. state as it goes along.
@ -430,7 +430,7 @@ async def _iterative_auth_checks(
Returns: Returns:
Returns the final updated state Returns the final updated state
""" """
resolved_state = base_state.copy() resolved_state = dict(base_state)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version] room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
for idx, event_id in enumerate(event_ids, start=1): for idx, event_id in enumerate(event_ids, start=1):

View file

@ -18,7 +18,7 @@ import re
import string import string
import sys import sys
from collections import namedtuple from collections import namedtuple
from typing import Any, Dict, Tuple, Type, TypeVar from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar
import attr import attr
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -41,8 +41,9 @@ else:
# Define a state map type from type/state_key to T (usually an event ID or # Define a state map type from type/state_key to T (usually an event ID or
# event) # event)
T = TypeVar("T") T = TypeVar("T")
StateMap = Dict[Tuple[str, str], T] StateKey = Tuple[str, str]
StateMap = Mapping[StateKey, T]
MutableStateMap = MutableMapping[StateKey, T]
# the type of a JSON-serialisable dict. This could be made stronger, but it will # the type of a JSON-serialisable dict. This could be made stronger, but it will
# do for now. # do for now.