mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-18 18:24:35 +01:00
Choose state algorithm based on room version
This commit is contained in:
parent
152c0aa58e
commit
ce6db0e547
4 changed files with 105 additions and 16 deletions
|
@ -274,8 +274,9 @@ class FederationHandler(BaseHandler):
|
||||||
ev_ids, get_prev_content=False, check_redacted=False
|
ev_ids, get_prev_content=False, check_redacted=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(pdu.room_id)
|
||||||
state_map = yield resolve_events_with_factory(
|
state_map = yield resolve_events_with_factory(
|
||||||
state_groups, {pdu.event_id: pdu}, fetch
|
room_version, state_groups, {pdu.event_id: pdu}, fetch
|
||||||
)
|
)
|
||||||
|
|
||||||
state = (yield self.store.get_events(state_map.values())).values()
|
state = (yield self.store.get_events(state_map.values())).values()
|
||||||
|
@ -1811,7 +1812,10 @@ class FederationHandler(BaseHandler):
|
||||||
(d.type, d.state_key): d for d in different_events if d
|
(d.type, d.state_key): d for d in different_events if d
|
||||||
})
|
})
|
||||||
|
|
||||||
new_state = self.state_handler.resolve_events(
|
room_version = yield self.store.get_room_version(event.room_id)
|
||||||
|
|
||||||
|
new_state = yield self.state_handler.resolve_events(
|
||||||
|
room_version,
|
||||||
[list(local_view.values()), list(remote_view.values())],
|
[list(local_view.values()), list(remote_view.values())],
|
||||||
event
|
event
|
||||||
)
|
)
|
||||||
|
|
|
@ -341,9 +341,10 @@ class RoomMemberHandler(object):
|
||||||
prev_events_and_hashes = yield self.store.get_prev_events_for_room(
|
prev_events_and_hashes = yield self.store.get_prev_events_for_room(
|
||||||
room_id,
|
room_id,
|
||||||
)
|
)
|
||||||
latest_event_ids = (
|
latest_event_ids = [
|
||||||
event_id for (event_id, _, _) in prev_events_and_hashes
|
event_id for (event_id, _, _) in prev_events_and_hashes
|
||||||
)
|
]
|
||||||
|
|
||||||
current_state_ids = yield self.state_handler.get_current_state_ids(
|
current_state_ids = yield self.state_handler.get_current_state_ids(
|
||||||
room_id, latest_event_ids=latest_event_ids,
|
room_id, latest_event_ids=latest_event_ids,
|
||||||
)
|
)
|
||||||
|
|
|
@ -23,16 +23,15 @@ from frozendict import frozendict
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes, RoomVersions
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
|
from synapse.state import v1
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.caches import CACHE_SIZE_FACTOR
|
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
from .v1 import resolve_events_with_factory, resolve_events_with_state_map
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -263,8 +262,14 @@ class StateHandler(object):
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
logger.debug("calling resolve_state_groups from compute_event_context")
|
logger.debug("calling resolve_state_groups from compute_event_context")
|
||||||
|
if event.type == EventTypes.Create:
|
||||||
|
room_version = event.content.get("room_version", RoomVersions.V1)
|
||||||
|
else:
|
||||||
|
room_version = None
|
||||||
|
|
||||||
entry = yield self.resolve_state_groups_for_events(
|
entry = yield self.resolve_state_groups_for_events(
|
||||||
event.room_id, [e for e, _ in event.prev_events],
|
event.room_id, [e for e, _ in event.prev_events],
|
||||||
|
explicit_room_version=room_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_state_ids = entry.state
|
prev_state_ids = entry.state
|
||||||
|
@ -332,13 +337,17 @@ class StateHandler(object):
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def resolve_state_groups_for_events(self, room_id, event_ids):
|
def resolve_state_groups_for_events(self, room_id, event_ids,
|
||||||
|
explicit_room_version=None):
|
||||||
""" Given a list of event_ids this method fetches the state at each
|
""" Given a list of event_ids this method fetches the state at each
|
||||||
event, resolves conflicts between them and returns them.
|
event, resolves conflicts between them and returns them.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id (str):
|
room_id (str)
|
||||||
event_ids (list[str]):
|
event_ids (list[str])
|
||||||
|
explicit_room_version (str|None): If set uses the the given room
|
||||||
|
version to choose the resolution algorithm. If None, then
|
||||||
|
checks the database for room version.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[_StateCacheEntry]: resolved state
|
Deferred[_StateCacheEntry]: resolved state
|
||||||
|
@ -364,8 +373,13 @@ class StateHandler(object):
|
||||||
delta_ids=delta_ids,
|
delta_ids=delta_ids,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
room_version = explicit_room_version
|
||||||
|
if not room_version:
|
||||||
|
room_version = yield self.store.get_room_version(room_id)
|
||||||
|
|
||||||
result = yield self._state_resolution_handler.resolve_state_groups(
|
result = yield self._state_resolution_handler.resolve_state_groups(
|
||||||
room_id, state_groups_ids, None, self._state_map_factory,
|
room_id, room_version, state_groups_ids, None,
|
||||||
|
self._state_map_factory,
|
||||||
)
|
)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@ -374,7 +388,8 @@ class StateHandler(object):
|
||||||
ev_ids, get_prev_content=False, check_redacted=False,
|
ev_ids, get_prev_content=False, check_redacted=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def resolve_events(self, state_sets, event):
|
@defer.inlineCallbacks
|
||||||
|
def resolve_events(self, room_version, state_sets, event):
|
||||||
logger.info(
|
logger.info(
|
||||||
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
|
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
|
||||||
)
|
)
|
||||||
|
@ -389,14 +404,18 @@ class StateHandler(object):
|
||||||
for ev in st
|
for ev in st
|
||||||
}
|
}
|
||||||
|
|
||||||
|
room_version = yield self.store.get_room_version(event.room_id)
|
||||||
|
|
||||||
with Measure(self.clock, "state._resolve_events"):
|
with Measure(self.clock, "state._resolve_events"):
|
||||||
new_state = resolve_events_with_state_map(state_set_ids, state_map)
|
new_state = resolve_events_with_state_map(
|
||||||
|
room_version, state_set_ids, state_map,
|
||||||
|
)
|
||||||
|
|
||||||
new_state = {
|
new_state = {
|
||||||
key: state_map[ev_id] for key, ev_id in iteritems(new_state)
|
key: state_map[ev_id] for key, ev_id in iteritems(new_state)
|
||||||
}
|
}
|
||||||
|
|
||||||
return new_state
|
defer.returnValue(new_state)
|
||||||
|
|
||||||
|
|
||||||
class StateResolutionHandler(object):
|
class StateResolutionHandler(object):
|
||||||
|
@ -429,7 +448,7 @@ class StateResolutionHandler(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def resolve_state_groups(
|
def resolve_state_groups(
|
||||||
self, room_id, state_groups_ids, event_map, state_map_factory,
|
self, room_id, room_version, state_groups_ids, event_map, state_map_factory,
|
||||||
):
|
):
|
||||||
"""Resolves conflicts between a set of state groups
|
"""Resolves conflicts between a set of state groups
|
||||||
|
|
||||||
|
@ -438,6 +457,7 @@ class StateResolutionHandler(object):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id (str): room we are resolving for (used for logging)
|
room_id (str): room we are resolving for (used for logging)
|
||||||
|
room_version (str): version of the room
|
||||||
state_groups_ids (dict[int, dict[(str, str), str]]):
|
state_groups_ids (dict[int, dict[(str, str), str]]):
|
||||||
map from state group id to the state in that state group
|
map from state group id to the state in that state group
|
||||||
(where 'state' is a map from state key to event id)
|
(where 'state' is a map from state key to event id)
|
||||||
|
@ -491,6 +511,7 @@ class StateResolutionHandler(object):
|
||||||
logger.info("Resolving conflicted state for %r", room_id)
|
logger.info("Resolving conflicted state for %r", room_id)
|
||||||
with Measure(self.clock, "state._resolve_events"):
|
with Measure(self.clock, "state._resolve_events"):
|
||||||
new_state = yield resolve_events_with_factory(
|
new_state = yield resolve_events_with_factory(
|
||||||
|
room_version,
|
||||||
list(itervalues(state_groups_ids)),
|
list(itervalues(state_groups_ids)),
|
||||||
event_map=event_map,
|
event_map=event_map,
|
||||||
state_map_factory=state_map_factory,
|
state_map_factory=state_map_factory,
|
||||||
|
@ -572,3 +593,64 @@ def _make_state_cache_entry(
|
||||||
prev_group=prev_group,
|
prev_group=prev_group,
|
||||||
delta_ids=delta_ids,
|
delta_ids=delta_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_events_with_state_map(room_version, state_sets, state_map):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
room_version(str): Version of the room
|
||||||
|
state_sets(list): List of dicts of (type, state_key) -> event_id,
|
||||||
|
which are the different state groups to resolve.
|
||||||
|
state_map(dict): a dict from event_id to event, for all events in
|
||||||
|
state_sets.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
dict[(str, str), str]:
|
||||||
|
a map from (type, state_key) to event_id.
|
||||||
|
"""
|
||||||
|
if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
|
||||||
|
return v1.resolve_events_with_state_map(
|
||||||
|
state_sets, state_map,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# This should only happen if we added a version but forgot to add it to
|
||||||
|
# the list above.
|
||||||
|
raise Exception(
|
||||||
|
"No state resolution algorithm defined for version %r" % (room_version,)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
room_version(str): Version of the room
|
||||||
|
|
||||||
|
state_sets(list): List of dicts of (type, state_key) -> event_id,
|
||||||
|
which are the different state groups to resolve.
|
||||||
|
|
||||||
|
event_map(dict[str,FrozenEvent]|None):
|
||||||
|
a dict from event_id to event, for any events that we happen to
|
||||||
|
have in flight (eg, those currently being persisted). This will be
|
||||||
|
used as a starting point fof finding the state we need; any missing
|
||||||
|
events will be requested via state_map_factory.
|
||||||
|
|
||||||
|
If None, all events will be fetched via state_map_factory.
|
||||||
|
|
||||||
|
state_map_factory(func): will be called
|
||||||
|
with a list of event_ids that are needed, and should return with
|
||||||
|
a Deferred of dict of event_id to event.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
Deferred[dict[(str, str), str]]:
|
||||||
|
a map from (type, state_key) to event_id.
|
||||||
|
"""
|
||||||
|
if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
|
||||||
|
return v1.resolve_events_with_factory(
|
||||||
|
state_sets, event_map, state_map_factory,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# This should only happen if we added a version but forgot to add it to
|
||||||
|
# the list above.
|
||||||
|
raise Exception(
|
||||||
|
"No state resolution algorithm defined for version %r" % (room_version,)
|
||||||
|
)
|
||||||
|
|
|
@ -705,9 +705,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
|
||||||
}
|
}
|
||||||
|
|
||||||
events_map = {ev.event_id: ev for ev, _ in events_context}
|
events_map = {ev.event_id: ev for ev, _ in events_context}
|
||||||
|
room_version = yield self.get_room_version(room_id)
|
||||||
|
|
||||||
logger.debug("calling resolve_state_groups from preserve_events")
|
logger.debug("calling resolve_state_groups from preserve_events")
|
||||||
res = yield self._state_resolution_handler.resolve_state_groups(
|
res = yield self._state_resolution_handler.resolve_state_groups(
|
||||||
room_id, state_groups, events_map, get_events
|
room_id, room_version, state_groups, events_map, get_events
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((res.state, None))
|
defer.returnValue((res.state, None))
|
||||||
|
|
Loading…
Reference in a new issue