mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 01:03:50 +01:00
Move state calculations from rest to handler
This commit is contained in:
parent
1ef7cae41b
commit
fa48020a52
2 changed files with 97 additions and 140 deletions
|
@ -23,6 +23,7 @@ from twisted.internet import defer
|
|||
|
||||
import collections
|
||||
import logging
|
||||
import itertools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -672,35 +673,10 @@ class SyncHandler(BaseHandler):
|
|||
account_data_by_room,
|
||||
all_ephemeral_by_room,
|
||||
batch, full_state=False):
|
||||
if full_state:
|
||||
state = yield self.get_state_at(room_id, now_token)
|
||||
|
||||
elif batch.limited:
|
||||
current_state = yield self.get_state_at(room_id, now_token)
|
||||
|
||||
state_at_previous_sync = yield self.get_state_at(
|
||||
room_id, stream_position=since_token
|
||||
)
|
||||
|
||||
state = yield self.compute_state_delta(
|
||||
since_token=since_token,
|
||||
previous_state=state_at_previous_sync,
|
||||
current_state=current_state,
|
||||
room_id, batch, sync_config, since_token, now_token,
|
||||
full_state=full_state
|
||||
)
|
||||
else:
|
||||
state = {
|
||||
(event.type, event.state_key): event
|
||||
for event in batch.events if event.is_state()
|
||||
}
|
||||
|
||||
just_joined = yield self.check_joined_room(sync_config, state)
|
||||
if just_joined:
|
||||
state = yield self.get_state_at(room_id, now_token)
|
||||
|
||||
state = {
|
||||
(e.type, e.state_key): e
|
||||
for e in sync_config.filter_collection.filter_room_state(state.values())
|
||||
}
|
||||
|
||||
account_data = self.account_data_for_room(
|
||||
room_id, tags_by_room, account_data_by_room
|
||||
|
@ -766,29 +742,10 @@ class SyncHandler(BaseHandler):
|
|||
|
||||
logger.debug("Recents %r", batch)
|
||||
|
||||
state_events_at_leave = yield self.store.get_state_for_event(
|
||||
leave_event_id
|
||||
)
|
||||
|
||||
if not full_state:
|
||||
state_at_previous_sync = yield self.get_state_at(
|
||||
room_id, stream_position=since_token
|
||||
)
|
||||
|
||||
state_events_delta = yield self.compute_state_delta(
|
||||
since_token=since_token,
|
||||
previous_state=state_at_previous_sync,
|
||||
current_state=state_events_at_leave,
|
||||
room_id, batch, sync_config, since_token, leave_token,
|
||||
full_state=full_state
|
||||
)
|
||||
else:
|
||||
state_events_delta = state_events_at_leave
|
||||
|
||||
state_events_delta = {
|
||||
(e.type, e.state_key): e
|
||||
for e in sync_config.filter_collection.filter_room_state(
|
||||
state_events_delta.values()
|
||||
)
|
||||
}
|
||||
|
||||
account_data = self.account_data_for_room(
|
||||
room_id, tags_by_room, account_data_by_room
|
||||
|
@ -843,15 +800,18 @@ class SyncHandler(BaseHandler):
|
|||
state = {}
|
||||
defer.returnValue(state)
|
||||
|
||||
def compute_state_delta(self, since_token, previous_state, current_state):
|
||||
""" Works out the differnce in state between the current state and the
|
||||
state the client got when it last performed a sync.
|
||||
@defer.inlineCallbacks
|
||||
def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
|
||||
full_state):
|
||||
""" Works out the differnce in state between the start of the timeline
|
||||
and the previous sync.
|
||||
|
||||
:param str since_token: the point we are comparing against
|
||||
:param dict[(str,str), synapse.events.FrozenEvent] previous_state: the
|
||||
state to compare to
|
||||
:param dict[(str,str), synapse.events.FrozenEvent] current_state: the
|
||||
new state
|
||||
:param str room_id
|
||||
:param TimelineBatch batch
|
||||
:param sync_config
|
||||
:param str since_token
|
||||
:param str now_token
|
||||
:param bool full_state
|
||||
|
||||
:returns A new event dictionary
|
||||
"""
|
||||
|
@ -860,12 +820,50 @@ class SyncHandler(BaseHandler):
|
|||
# updates even if they occured logically before the previous event.
|
||||
# TODO(mjark) Check for new redactions in the state events.
|
||||
|
||||
state_delta = {}
|
||||
for key, event in current_state.iteritems():
|
||||
if (key not in previous_state or
|
||||
previous_state[key].event_id != event.event_id):
|
||||
state_delta[key] = event
|
||||
return state_delta
|
||||
if full_state:
|
||||
if batch:
|
||||
state = yield self.store.get_state_for_event(batch.events[0].event_id)
|
||||
else:
|
||||
state = yield self.get_state_at(
|
||||
room_id, stream_position=now_token
|
||||
)
|
||||
|
||||
timeline_state = {
|
||||
(event.type, event.state_key): event
|
||||
for event in batch.events if event.is_state()
|
||||
}
|
||||
|
||||
state = _calculate_state(
|
||||
timeline_contains=timeline_state,
|
||||
timeline_start=state,
|
||||
previous={},
|
||||
)
|
||||
elif batch.limited:
|
||||
state_at_previous_sync = yield self.get_state_at(
|
||||
room_id, stream_position=since_token
|
||||
)
|
||||
|
||||
state_at_timeline_start = yield self.store.get_state_for_event(
|
||||
batch.events[0].event_id
|
||||
)
|
||||
|
||||
timeline_state = {
|
||||
(event.type, event.state_key): event
|
||||
for event in batch.events if event.is_state()
|
||||
}
|
||||
|
||||
state = _calculate_state(
|
||||
timeline_contains=timeline_state,
|
||||
timeline_start=state_at_timeline_start,
|
||||
previous=state_at_previous_sync,
|
||||
)
|
||||
else:
|
||||
state = {}
|
||||
|
||||
defer.returnValue({
|
||||
(e.type, e.state_key): e
|
||||
for e in sync_config.filter_collection.filter_room_state(state.values())
|
||||
})
|
||||
|
||||
def check_joined_room(self, sync_config, state_delta):
|
||||
"""
|
||||
|
@ -912,3 +910,37 @@ def _action_has_highlight(actions):
|
|||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _calculate_state(timeline_contains, timeline_start, previous):
|
||||
"""Works out what state to include in a sync response.
|
||||
|
||||
Args:
|
||||
timeline_contains (dict): state in the timeline
|
||||
timeline_start (dict): state at the start of the timeline
|
||||
previous (dict): state at the end of the previous sync (or empty dict
|
||||
if there is an initial sync)
|
||||
|
||||
Returns:
|
||||
dict
|
||||
"""
|
||||
event_id_to_state = {
|
||||
e.event_id: e
|
||||
for e in itertools.chain(
|
||||
timeline_contains.values(),
|
||||
previous.values(),
|
||||
timeline_start.values(),
|
||||
)
|
||||
}
|
||||
|
||||
tc_ids = set(e.event_id for e in timeline_contains.values())
|
||||
p_ids = set(e.event_id for e in previous.values())
|
||||
ts_ids = set(e.event_id for e in timeline_start.values())
|
||||
|
||||
state_ids = (ts_ids - p_ids) - tc_ids
|
||||
|
||||
evs = (event_id_to_state[e] for e in state_ids)
|
||||
return {
|
||||
(e.type, e.state_key): e
|
||||
for e in evs
|
||||
}
|
||||
|
|
|
@ -20,7 +20,6 @@ from synapse.http.servlet import (
|
|||
)
|
||||
from synapse.handlers.sync import SyncConfig
|
||||
from synapse.types import StreamToken
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.events.utils import (
|
||||
serialize_event, format_event_for_client_v2_without_room_id,
|
||||
)
|
||||
|
@ -287,9 +286,6 @@ class SyncRestServlet(RestServlet):
|
|||
state_dict = room.state
|
||||
timeline_events = room.timeline.events
|
||||
|
||||
state_dict = SyncRestServlet._rollback_state_for_timeline(
|
||||
state_dict, timeline_events)
|
||||
|
||||
state_events = state_dict.values()
|
||||
|
||||
serialized_state = [serialize(e) for e in state_events]
|
||||
|
@ -314,77 +310,6 @@ class SyncRestServlet(RestServlet):
|
|||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _rollback_state_for_timeline(state, timeline):
|
||||
"""
|
||||
Wind the state dictionary backwards, so that it represents the
|
||||
state at the start of the timeline, rather than at the end.
|
||||
|
||||
:param dict[(str, str), synapse.events.EventBase] state: the
|
||||
state dictionary. Will be updated to the state before the timeline.
|
||||
:param list[synapse.events.EventBase] timeline: the event timeline
|
||||
:return: updated state dictionary
|
||||
"""
|
||||
|
||||
result = state.copy()
|
||||
|
||||
for timeline_event in reversed(timeline):
|
||||
if not timeline_event.is_state():
|
||||
continue
|
||||
|
||||
event_key = (timeline_event.type, timeline_event.state_key)
|
||||
|
||||
logger.debug("Considering %s for removal", event_key)
|
||||
|
||||
state_event = result.get(event_key)
|
||||
if (state_event is None or
|
||||
state_event.event_id != timeline_event.event_id):
|
||||
# the event in the timeline isn't present in the state
|
||||
# dictionary.
|
||||
#
|
||||
# the most likely cause for this is that there was a fork in
|
||||
# the event graph, and the state is no longer valid. Really,
|
||||
# the event shouldn't be in the timeline. We're going to ignore
|
||||
# it for now, however.
|
||||
logger.debug("Found state event %r in timeline which doesn't "
|
||||
"match state dictionary", timeline_event)
|
||||
continue
|
||||
|
||||
prev_event_id = timeline_event.unsigned.get("replaces_state", None)
|
||||
|
||||
prev_content = timeline_event.unsigned.get('prev_content')
|
||||
prev_sender = timeline_event.unsigned.get('prev_sender')
|
||||
# Empircally it seems possible for the event to have a
|
||||
# "replaces_state" key but not a prev_content or prev_sender
|
||||
# markjh conjectures that it could be due to the server not
|
||||
# having a copy of that event.
|
||||
# If this is the case the we ignore the previous event. This will
|
||||
# cause the displayname calculations on the client to be incorrect
|
||||
if prev_event_id is None or not prev_content or not prev_sender:
|
||||
logger.debug(
|
||||
"Removing %r from the state dict, as it is missing"
|
||||
" prev_content (prev_event_id=%r)",
|
||||
timeline_event.event_id, prev_event_id
|
||||
)
|
||||
del result[event_key]
|
||||
else:
|
||||
logger.debug(
|
||||
"Replacing %r with %r in state dict",
|
||||
timeline_event.event_id, prev_event_id
|
||||
)
|
||||
result[event_key] = FrozenEvent({
|
||||
"type": timeline_event.type,
|
||||
"state_key": timeline_event.state_key,
|
||||
"content": prev_content,
|
||||
"sender": prev_sender,
|
||||
"event_id": prev_event_id,
|
||||
"room_id": timeline_event.room_id,
|
||||
})
|
||||
|
||||
logger.debug("New value: %r", result.get(event_key))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
SyncRestServlet(hs).register(http_server)
|
||||
|
|
Loading…
Reference in a new issue