Move state calculations from rest to handler

commit fa48020a52
parent 1ef7cae41b

2 changed files with 97 additions and 140 deletions
Sync handler changes:

@@ -23,6 +23,7 @@ from twisted.internet import defer

 import collections
 import logging
+import itertools

 logger = logging.getLogger(__name__)
@@ -672,35 +673,10 @@ class SyncHandler(BaseHandler):
                                            account_data_by_room,
                                            all_ephemeral_by_room,
                                            batch, full_state=False):
-        if full_state:
-            state = yield self.get_state_at(room_id, now_token)
-
-        elif batch.limited:
-            current_state = yield self.get_state_at(room_id, now_token)
-
-            state_at_previous_sync = yield self.get_state_at(
-                room_id, stream_position=since_token
-            )
-
-            state = yield self.compute_state_delta(
-                since_token=since_token,
-                previous_state=state_at_previous_sync,
-                current_state=current_state,
-            )
-        else:
-            state = {
-                (event.type, event.state_key): event
-                for event in batch.events if event.is_state()
-            }
-
-        just_joined = yield self.check_joined_room(sync_config, state)
-        if just_joined:
-            state = yield self.get_state_at(room_id, now_token)
-
-        state = {
-            (e.type, e.state_key): e
-            for e in sync_config.filter_collection.filter_room_state(state.values())
-        }
+        state = yield self.compute_state_delta(
+            room_id, batch, sync_config, since_token, now_token,
+            full_state=full_state
+        )

         account_data = self.account_data_for_room(
             room_id, tags_by_room, account_data_by_room
@@ -766,30 +742,11 @@ class SyncHandler(BaseHandler):

         logger.debug("Recents %r", batch)

-        state_events_at_leave = yield self.store.get_state_for_event(
-            leave_event_id
-        )
-
-        if not full_state:
-            state_at_previous_sync = yield self.get_state_at(
-                room_id, stream_position=since_token
-            )
-
-            state_events_delta = yield self.compute_state_delta(
-                since_token=since_token,
-                previous_state=state_at_previous_sync,
-                current_state=state_events_at_leave,
-            )
-        else:
-            state_events_delta = state_events_at_leave
-
-        state_events_delta = {
-            (e.type, e.state_key): e
-            for e in sync_config.filter_collection.filter_room_state(
-                state_events_delta.values()
-            )
-        }
+        state_events_delta = yield self.compute_state_delta(
+            room_id, batch, sync_config, since_token, leave_token,
+            full_state=full_state
+        )

         account_data = self.account_data_for_room(
             room_id, tags_by_room, account_data_by_room
         )
@@ -843,15 +800,18 @@ class SyncHandler(BaseHandler):
             state = {}
         defer.returnValue(state)

-    def compute_state_delta(self, since_token, previous_state, current_state):
-        """ Works out the differnce in state between the current state and the
-        state the client got when it last performed a sync.
-
-        :param str since_token: the point we are comparing against
-        :param dict[(str,str), synapse.events.FrozenEvent] previous_state: the
-            state to compare to
-        :param dict[(str,str), synapse.events.FrozenEvent] current_state: the
-            new state
-
+    @defer.inlineCallbacks
+    def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
+                            full_state):
+        """ Works out the differnce in state between the start of the timeline
+        and the previous sync.
+
+        :param str room_id
+        :param TimelineBatch batch
+        :param sync_config
+        :param str since_token
+        :param str now_token
+        :param bool full_state
+
         :returns A new event dictionary
         """
@@ -860,12 +820,50 @@ class SyncHandler(BaseHandler):
        # updates even if they occured logically before the previous event.
        # TODO(mjark) Check for new redactions in the state events.

-        state_delta = {}
-        for key, event in current_state.iteritems():
-            if (key not in previous_state or
-                    previous_state[key].event_id != event.event_id):
-                state_delta[key] = event
-        return state_delta
+        if full_state:
+            if batch:
+                state = yield self.store.get_state_for_event(batch.events[0].event_id)
+            else:
+                state = yield self.get_state_at(
+                    room_id, stream_position=now_token
+                )
+
+            timeline_state = {
+                (event.type, event.state_key): event
+                for event in batch.events if event.is_state()
+            }
+
+            state = _calculate_state(
+                timeline_contains=timeline_state,
+                timeline_start=state,
+                previous={},
+            )
+        elif batch.limited:
+            state_at_previous_sync = yield self.get_state_at(
+                room_id, stream_position=since_token
+            )
+
+            state_at_timeline_start = yield self.store.get_state_for_event(
+                batch.events[0].event_id
+            )
+
+            timeline_state = {
+                (event.type, event.state_key): event
+                for event in batch.events if event.is_state()
+            }
+
+            state = _calculate_state(
+                timeline_contains=timeline_state,
+                timeline_start=state_at_timeline_start,
+                previous=state_at_previous_sync,
+            )
+        else:
+            state = {}
+
+        defer.returnValue({
+            (e.type, e.state_key): e
+            for e in sync_config.filter_collection.filter_room_state(state.values())
+        })

     def check_joined_room(self, sync_config, state_delta):
         """
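For readers of the diff: the reworked compute_state_delta distinguishes three cases, all expressed through the _calculate_state helper added at the bottom of this file (next hunk). The snippet below is a minimal, dependency-free sketch of that decision, not Synapse code; sketch_state_for_sync and its arguments are stand-ins for the handler's real inputs (dicts keyed by (event type, state_key) whose values expose an event_id attribute).

# Minimal sketch, assuming stand-in dicts; it mirrors the branching above
# without deferreds or the event store.
def sketch_state_for_sync(full_state, limited, timeline_state,
                          state_at_timeline_start, state_at_previous_sync):
    if full_state:
        # Initial/full sync: everything at the timeline start counts,
        # nothing was sent previously.
        base, previous = state_at_timeline_start, {}
    elif limited:
        # Gappy incremental sync: also subtract what the previous sync
        # already delivered.
        base, previous = state_at_timeline_start, state_at_previous_sync
    else:
        # Contiguous incremental sync: the timeline covers everything,
        # so no extra state is sent.
        return {}

    already_sent = {e.event_id for e in timeline_state.values()}
    already_sent |= {e.event_id for e in previous.values()}
    return {
        key: event
        for key, event in base.items()
        if event.event_id not in already_sent
    }

In the real handler the result is additionally passed through sync_config.filter_collection.filter_room_state() before being returned, as the hunk above shows.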
@@ -912,3 +910,37 @@ def _action_has_highlight(actions):
         pass

     return False
+
+
+def _calculate_state(timeline_contains, timeline_start, previous):
+    """Works out what state to include in a sync response.
+
+    Args:
+        timeline_contains (dict): state in the timeline
+        timeline_start (dict): state at the start of the timeline
+        previous (dict): state at the end of the previous sync (or empty dict
+            if there is an initial sync)
+
+    Returns:
+        dict
+    """
+    event_id_to_state = {
+        e.event_id: e
+        for e in itertools.chain(
+            timeline_contains.values(),
+            previous.values(),
+            timeline_start.values(),
+        )
+    }
+
+    tc_ids = set(e.event_id for e in timeline_contains.values())
+    p_ids = set(e.event_id for e in previous.values())
+    ts_ids = set(e.event_id for e in timeline_start.values())
+
+    state_ids = (ts_ids - p_ids) - tc_ids
+
+    evs = (event_id_to_state[e] for e in state_ids)
+    return {
+        (e.type, e.state_key): e
+        for e in evs
+    }
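As a concrete illustration of the set arithmetic in _calculate_state, here is a hypothetical usage sketch with stub events (a namedtuple stands in for synapse.events.FrozenEvent); it assumes the helper added above is in scope.

# Hypothetical usage sketch; StubEvent and the sample events are invented for
# illustration and are not Synapse objects.
from collections import namedtuple

StubEvent = namedtuple("StubEvent", ["event_id", "type", "state_key"])

create = StubEvent("$create", "m.room.create", "")
old_name = StubEvent("$name1", "m.room.name", "")
new_name = StubEvent("$name2", "m.room.name", "")
join = StubEvent("$join", "m.room.member", "@alice:example.com")

# The timeline shows a rename; the state at its start also includes the
# create event (already sent at the previous sync) and Alice's join.
timeline_contains = {("m.room.name", ""): new_name}
timeline_start = {
    ("m.room.create", ""): create,
    ("m.room.name", ""): old_name,
    ("m.room.member", "@alice:example.com"): join,
}
previous = {("m.room.create", ""): create}

state = _calculate_state(
    timeline_contains=timeline_contains,
    timeline_start=timeline_start,
    previous=previous,
)
# state == {
#     ("m.room.name", ""): old_name,
#     ("m.room.member", "@alice:example.com"): join,
# }
# The create event is not repeated (it was sent at the previous sync), and the
# rename itself arrives in the timeline, so the state section only carries the
# room name and membership as they stood at the start of the timeline.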
Sync REST servlet changes:

@@ -20,7 +20,6 @@ from synapse.http.servlet import (
 )
 from synapse.handlers.sync import SyncConfig
 from synapse.types import StreamToken
-from synapse.events import FrozenEvent
 from synapse.events.utils import (
     serialize_event, format_event_for_client_v2_without_room_id,
 )
@@ -287,9 +286,6 @@ class SyncRestServlet(RestServlet):
         state_dict = room.state
         timeline_events = room.timeline.events

-        state_dict = SyncRestServlet._rollback_state_for_timeline(
-            state_dict, timeline_events)
-
         state_events = state_dict.values()

         serialized_state = [serialize(e) for e in state_events]
@@ -314,77 +310,6 @@ class SyncRestServlet(RestServlet):

         return result

-    @staticmethod
-    def _rollback_state_for_timeline(state, timeline):
-        """
-        Wind the state dictionary backwards, so that it represents the
-        state at the start of the timeline, rather than at the end.
-
-        :param dict[(str, str), synapse.events.EventBase] state: the
-            state dictionary. Will be updated to the state before the timeline.
-        :param list[synapse.events.EventBase] timeline: the event timeline
-        :return: updated state dictionary
-        """
-
-        result = state.copy()
-
-        for timeline_event in reversed(timeline):
-            if not timeline_event.is_state():
-                continue
-
-            event_key = (timeline_event.type, timeline_event.state_key)
-
-            logger.debug("Considering %s for removal", event_key)
-
-            state_event = result.get(event_key)
-            if (state_event is None or
-                    state_event.event_id != timeline_event.event_id):
-                # the event in the timeline isn't present in the state
-                # dictionary.
-                #
-                # the most likely cause for this is that there was a fork in
-                # the event graph, and the state is no longer valid. Really,
-                # the event shouldn't be in the timeline. We're going to ignore
-                # it for now, however.
-                logger.debug("Found state event %r in timeline which doesn't "
-                             "match state dictionary", timeline_event)
-                continue
-
-            prev_event_id = timeline_event.unsigned.get("replaces_state", None)
-
-            prev_content = timeline_event.unsigned.get('prev_content')
-            prev_sender = timeline_event.unsigned.get('prev_sender')
-            # Empircally it seems possible for the event to have a
-            # "replaces_state" key but not a prev_content or prev_sender
-            # markjh conjectures that it could be due to the server not
-            # having a copy of that event.
-            # If this is the case the we ignore the previous event. This will
-            # cause the displayname calculations on the client to be incorrect
-            if prev_event_id is None or not prev_content or not prev_sender:
-                logger.debug(
-                    "Removing %r from the state dict, as it is missing"
-                    " prev_content (prev_event_id=%r)",
-                    timeline_event.event_id, prev_event_id
-                )
-                del result[event_key]
-            else:
-                logger.debug(
-                    "Replacing %r with %r in state dict",
-                    timeline_event.event_id, prev_event_id
-                )
-                result[event_key] = FrozenEvent({
-                    "type": timeline_event.type,
-                    "state_key": timeline_event.state_key,
-                    "content": prev_content,
-                    "sender": prev_sender,
-                    "event_id": prev_event_id,
-                    "room_id": timeline_event.room_id,
-                })
-
-            logger.debug("New value: %r", result.get(event_key))
-
-        return result
-

 def register_servlets(hs, http_server):
     SyncRestServlet(hs).register(http_server)
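The REST-side rollback removed above is what the handler-side calculation replaces: instead of winding each timeline state event back through its unsigned "replaces_state" / "prev_content", the handler now asks the store for the state at the first timeline event. Below is a small illustrative sketch (plain dicts, not Synapse event classes; rollback_one and the sample event are invented) of the corner case called out in the removed comments, where an event with "replaces_state" but no prev_content had to be deleted from the rolled-back state.

# Illustrative stand-in for the removed _rollback_state_for_timeline rule.
def rollback_one(timeline_event):
    """Return the previous state event for one timeline entry, or None if it
    has to be dropped (the case the removed comments describe)."""
    unsigned = timeline_event.get("unsigned", {})
    prev_event_id = unsigned.get("replaces_state")
    prev_content = unsigned.get("prev_content")
    prev_sender = unsigned.get("prev_sender")
    if prev_event_id is None or not prev_content or not prev_sender:
        return None  # old code did: del result[event_key]
    return {
        "type": timeline_event["type"],
        "state_key": timeline_event["state_key"],
        "content": prev_content,
        "sender": prev_sender,
        "event_id": prev_event_id,
        "room_id": timeline_event["room_id"],
    }

# A rename whose server copy lacks prev_content/prev_sender: the old rollback
# dropped the room name from the state section entirely, whereas the handler
# can now fetch the correct earlier state from the store.
rename = {
    "type": "m.room.name", "state_key": "", "room_id": "!room:example.com",
    "event_id": "$name2", "unsigned": {"replaces_state": "$name1"},
}
assert rollback_one(rename) is None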