mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-18 07:33:49 +01:00
Choose state algorithm based on room version
This commit is contained in:
parent
152c0aa58e
commit
ce6db0e547
4 changed files with 105 additions and 16 deletions
|
@ -274,8 +274,9 @@ class FederationHandler(BaseHandler):
|
|||
ev_ids, get_prev_content=False, check_redacted=False
|
||||
)
|
||||
|
||||
room_version = yield self.store.get_room_version(pdu.room_id)
|
||||
state_map = yield resolve_events_with_factory(
|
||||
state_groups, {pdu.event_id: pdu}, fetch
|
||||
room_version, state_groups, {pdu.event_id: pdu}, fetch
|
||||
)
|
||||
|
||||
state = (yield self.store.get_events(state_map.values())).values()
|
||||
|
@ -1811,7 +1812,10 @@ class FederationHandler(BaseHandler):
|
|||
(d.type, d.state_key): d for d in different_events if d
|
||||
})
|
||||
|
||||
new_state = self.state_handler.resolve_events(
|
||||
room_version = yield self.store.get_room_version(event.room_id)
|
||||
|
||||
new_state = yield self.state_handler.resolve_events(
|
||||
room_version,
|
||||
[list(local_view.values()), list(remote_view.values())],
|
||||
event
|
||||
)
|
||||
|
|
|
@ -341,9 +341,10 @@ class RoomMemberHandler(object):
|
|||
prev_events_and_hashes = yield self.store.get_prev_events_for_room(
|
||||
room_id,
|
||||
)
|
||||
latest_event_ids = (
|
||||
latest_event_ids = [
|
||||
event_id for (event_id, _, _) in prev_events_and_hashes
|
||||
)
|
||||
]
|
||||
|
||||
current_state_ids = yield self.state_handler.get_current_state_ids(
|
||||
room_id, latest_event_ids=latest_event_ids,
|
||||
)
|
||||
|
|
|
@ -23,16 +23,15 @@ from frozendict import frozendict
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.constants import EventTypes, RoomVersions
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.state import v1
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
from .v1 import resolve_events_with_factory, resolve_events_with_state_map
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -263,8 +262,14 @@ class StateHandler(object):
|
|||
defer.returnValue(context)
|
||||
|
||||
logger.debug("calling resolve_state_groups from compute_event_context")
|
||||
if event.type == EventTypes.Create:
|
||||
room_version = event.content.get("room_version", RoomVersions.V1)
|
||||
else:
|
||||
room_version = None
|
||||
|
||||
entry = yield self.resolve_state_groups_for_events(
|
||||
event.room_id, [e for e, _ in event.prev_events],
|
||||
explicit_room_version=room_version,
|
||||
)
|
||||
|
||||
prev_state_ids = entry.state
|
||||
|
@ -332,13 +337,17 @@ class StateHandler(object):
|
|||
defer.returnValue(context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_state_groups_for_events(self, room_id, event_ids):
|
||||
def resolve_state_groups_for_events(self, room_id, event_ids,
|
||||
explicit_room_version=None):
|
||||
""" Given a list of event_ids this method fetches the state at each
|
||||
event, resolves conflicts between them and returns them.
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
event_ids (list[str]):
|
||||
room_id (str)
|
||||
event_ids (list[str])
|
||||
explicit_room_version (str|None): If set uses the the given room
|
||||
version to choose the resolution algorithm. If None, then
|
||||
checks the database for room version.
|
||||
|
||||
Returns:
|
||||
Deferred[_StateCacheEntry]: resolved state
|
||||
|
@ -364,8 +373,13 @@ class StateHandler(object):
|
|||
delta_ids=delta_ids,
|
||||
))
|
||||
|
||||
room_version = explicit_room_version
|
||||
if not room_version:
|
||||
room_version = yield self.store.get_room_version(room_id)
|
||||
|
||||
result = yield self._state_resolution_handler.resolve_state_groups(
|
||||
room_id, state_groups_ids, None, self._state_map_factory,
|
||||
room_id, room_version, state_groups_ids, None,
|
||||
self._state_map_factory,
|
||||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
|
@ -374,7 +388,8 @@ class StateHandler(object):
|
|||
ev_ids, get_prev_content=False, check_redacted=False,
|
||||
)
|
||||
|
||||
def resolve_events(self, state_sets, event):
|
||||
@defer.inlineCallbacks
|
||||
def resolve_events(self, room_version, state_sets, event):
|
||||
logger.info(
|
||||
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
|
||||
)
|
||||
|
@ -389,14 +404,18 @@ class StateHandler(object):
|
|||
for ev in st
|
||||
}
|
||||
|
||||
room_version = yield self.store.get_room_version(event.room_id)
|
||||
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = resolve_events_with_state_map(state_set_ids, state_map)
|
||||
new_state = resolve_events_with_state_map(
|
||||
room_version, state_set_ids, state_map,
|
||||
)
|
||||
|
||||
new_state = {
|
||||
key: state_map[ev_id] for key, ev_id in iteritems(new_state)
|
||||
}
|
||||
|
||||
return new_state
|
||||
defer.returnValue(new_state)
|
||||
|
||||
|
||||
class StateResolutionHandler(object):
|
||||
|
@ -429,7 +448,7 @@ class StateResolutionHandler(object):
|
|||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def resolve_state_groups(
|
||||
self, room_id, state_groups_ids, event_map, state_map_factory,
|
||||
self, room_id, room_version, state_groups_ids, event_map, state_map_factory,
|
||||
):
|
||||
"""Resolves conflicts between a set of state groups
|
||||
|
||||
|
@ -438,6 +457,7 @@ class StateResolutionHandler(object):
|
|||
|
||||
Args:
|
||||
room_id (str): room we are resolving for (used for logging)
|
||||
room_version (str): version of the room
|
||||
state_groups_ids (dict[int, dict[(str, str), str]]):
|
||||
map from state group id to the state in that state group
|
||||
(where 'state' is a map from state key to event id)
|
||||
|
@ -491,6 +511,7 @@ class StateResolutionHandler(object):
|
|||
logger.info("Resolving conflicted state for %r", room_id)
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = yield resolve_events_with_factory(
|
||||
room_version,
|
||||
list(itervalues(state_groups_ids)),
|
||||
event_map=event_map,
|
||||
state_map_factory=state_map_factory,
|
||||
|
@ -572,3 +593,64 @@ def _make_state_cache_entry(
|
|||
prev_group=prev_group,
|
||||
delta_ids=delta_ids,
|
||||
)
|
||||
|
||||
|
||||
def resolve_events_with_state_map(room_version, state_sets, state_map):
|
||||
"""
|
||||
Args:
|
||||
room_version(str): Version of the room
|
||||
state_sets(list): List of dicts of (type, state_key) -> event_id,
|
||||
which are the different state groups to resolve.
|
||||
state_map(dict): a dict from event_id to event, for all events in
|
||||
state_sets.
|
||||
|
||||
Returns
|
||||
dict[(str, str), str]:
|
||||
a map from (type, state_key) to event_id.
|
||||
"""
|
||||
if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
|
||||
return v1.resolve_events_with_state_map(
|
||||
state_sets, state_map,
|
||||
)
|
||||
else:
|
||||
# This should only happen if we added a version but forgot to add it to
|
||||
# the list above.
|
||||
raise Exception(
|
||||
"No state resolution algorithm defined for version %r" % (room_version,)
|
||||
)
|
||||
|
||||
|
||||
def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory):
|
||||
"""
|
||||
Args:
|
||||
room_version(str): Version of the room
|
||||
|
||||
state_sets(list): List of dicts of (type, state_key) -> event_id,
|
||||
which are the different state groups to resolve.
|
||||
|
||||
event_map(dict[str,FrozenEvent]|None):
|
||||
a dict from event_id to event, for any events that we happen to
|
||||
have in flight (eg, those currently being persisted). This will be
|
||||
used as a starting point fof finding the state we need; any missing
|
||||
events will be requested via state_map_factory.
|
||||
|
||||
If None, all events will be fetched via state_map_factory.
|
||||
|
||||
state_map_factory(func): will be called
|
||||
with a list of event_ids that are needed, and should return with
|
||||
a Deferred of dict of event_id to event.
|
||||
|
||||
Returns
|
||||
Deferred[dict[(str, str), str]]:
|
||||
a map from (type, state_key) to event_id.
|
||||
"""
|
||||
if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
|
||||
return v1.resolve_events_with_factory(
|
||||
state_sets, event_map, state_map_factory,
|
||||
)
|
||||
else:
|
||||
# This should only happen if we added a version but forgot to add it to
|
||||
# the list above.
|
||||
raise Exception(
|
||||
"No state resolution algorithm defined for version %r" % (room_version,)
|
||||
)
|
||||
|
|
|
@ -705,9 +705,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
|
|||
}
|
||||
|
||||
events_map = {ev.event_id: ev for ev, _ in events_context}
|
||||
room_version = yield self.get_room_version(room_id)
|
||||
|
||||
logger.debug("calling resolve_state_groups from preserve_events")
|
||||
res = yield self._state_resolution_handler.resolve_state_groups(
|
||||
room_id, state_groups, events_map, get_events
|
||||
room_id, room_version, state_groups, events_map, get_events
|
||||
)
|
||||
|
||||
defer.returnValue((res.state, None))
|
||||
|
|
Loading…
Reference in a new issue