mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 01:33:52 +01:00
Make sync not pull out full state
This commit is contained in:
parent
7356d52e73
commit
778fa85f47
2 changed files with 74 additions and 34 deletions
|
@ -355,11 +355,11 @@ class SyncHandler(object):
|
|||
Returns:
|
||||
A Deferred map from ((type, state_key)->Event)
|
||||
"""
|
||||
state = yield self.store.get_state_for_event(event.event_id)
|
||||
state_ids = yield self.store.get_state_ids_for_event(event.event_id)
|
||||
if event.is_state():
|
||||
state = state.copy()
|
||||
state[(event.type, event.state_key)] = event
|
||||
defer.returnValue(state)
|
||||
state_ids = state_ids.copy()
|
||||
state_ids[(event.type, event.state_key)] = event.event_id
|
||||
defer.returnValue(state_ids)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_at(self, room_id, stream_position):
|
||||
|
@ -412,57 +412,61 @@ class SyncHandler(object):
|
|||
with Measure(self.clock, "compute_state_delta"):
|
||||
if full_state:
|
||||
if batch:
|
||||
current_state = yield self.store.get_state_for_event(
|
||||
current_state_ids = yield self.store.get_state_ids_for_event(
|
||||
batch.events[-1].event_id
|
||||
)
|
||||
|
||||
state = yield self.store.get_state_for_event(
|
||||
state_ids = yield self.store.get_state_ids_for_event(
|
||||
batch.events[0].event_id
|
||||
)
|
||||
else:
|
||||
current_state = yield self.get_state_at(
|
||||
current_state_ids = yield self.get_state_at(
|
||||
room_id, stream_position=now_token
|
||||
)
|
||||
|
||||
state = current_state
|
||||
state_ids = current_state_ids
|
||||
|
||||
timeline_state = {
|
||||
(event.type, event.state_key): event
|
||||
(event.type, event.state_key): event.event_id
|
||||
for event in batch.events if event.is_state()
|
||||
}
|
||||
|
||||
state = _calculate_state(
|
||||
state_ids = _calculate_state(
|
||||
timeline_contains=timeline_state,
|
||||
timeline_start=state,
|
||||
timeline_start=state_ids,
|
||||
previous={},
|
||||
current=current_state,
|
||||
current=current_state_ids,
|
||||
)
|
||||
elif batch.limited:
|
||||
state_at_previous_sync = yield self.get_state_at(
|
||||
room_id, stream_position=since_token
|
||||
)
|
||||
|
||||
current_state = yield self.store.get_state_for_event(
|
||||
current_state_ids = yield self.store.get_state_ids_for_event(
|
||||
batch.events[-1].event_id
|
||||
)
|
||||
|
||||
state_at_timeline_start = yield self.store.get_state_for_event(
|
||||
state_at_timeline_start = yield self.store.get_state_ids_for_event(
|
||||
batch.events[0].event_id
|
||||
)
|
||||
|
||||
timeline_state = {
|
||||
(event.type, event.state_key): event
|
||||
(event.type, event.state_key): event.event_id
|
||||
for event in batch.events if event.is_state()
|
||||
}
|
||||
|
||||
state = _calculate_state(
|
||||
state_ids = _calculate_state(
|
||||
timeline_contains=timeline_state,
|
||||
timeline_start=state_at_timeline_start,
|
||||
previous=state_at_previous_sync,
|
||||
current=current_state,
|
||||
current=current_state_ids,
|
||||
)
|
||||
else:
|
||||
state_ids = {}
|
||||
|
||||
state = {}
|
||||
if state_ids:
|
||||
state = yield self.store.get_events(state_ids.values())
|
||||
|
||||
defer.returnValue({
|
||||
(e.type, e.state_key): e
|
||||
|
@ -766,8 +770,13 @@ class SyncHandler(object):
|
|||
# the last sync (even if we have since left). This is to make sure
|
||||
# we do send down the room, and with full state, where necessary
|
||||
if room_id in joined_room_ids or has_join:
|
||||
old_state = yield self.get_state_at(room_id, since_token)
|
||||
old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
|
||||
old_state_ids = yield self.get_state_at(room_id, since_token)
|
||||
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
|
||||
old_mem_ev = None
|
||||
if old_mem_ev_id:
|
||||
old_mem_ev = yield self.store.get_event(
|
||||
old_mem_ev_id, allow_none=True
|
||||
)
|
||||
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
|
||||
newly_joined_rooms.append(room_id)
|
||||
|
||||
|
@ -1059,27 +1068,25 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
|
|||
Returns:
|
||||
dict
|
||||
"""
|
||||
event_id_to_state = {
|
||||
e.event_id: e
|
||||
for e in itertools.chain(
|
||||
timeline_contains.values(),
|
||||
previous.values(),
|
||||
timeline_start.values(),
|
||||
current.values(),
|
||||
event_id_to_key = {
|
||||
e: key
|
||||
for key, e in itertools.chain(
|
||||
timeline_contains.items(),
|
||||
previous.items(),
|
||||
timeline_start.items(),
|
||||
current.items(),
|
||||
)
|
||||
}
|
||||
|
||||
c_ids = set(e.event_id for e in current.values())
|
||||
tc_ids = set(e.event_id for e in timeline_contains.values())
|
||||
p_ids = set(e.event_id for e in previous.values())
|
||||
ts_ids = set(e.event_id for e in timeline_start.values())
|
||||
c_ids = set(e for e in current.values())
|
||||
tc_ids = set(e for e in timeline_contains.values())
|
||||
p_ids = set(e for e in previous.values())
|
||||
ts_ids = set(e for e in timeline_start.values())
|
||||
|
||||
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
|
||||
|
||||
evs = (event_id_to_state[e] for e in state_ids)
|
||||
return {
|
||||
(e.type, e.state_key): e
|
||||
for e in evs
|
||||
event_id_to_key[e]: e for e in state_ids
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -283,6 +283,22 @@ class StateStore(SQLBaseStore):
|
|||
|
||||
defer.returnValue({event: event_to_state[event] for event in event_ids})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_ids_for_events(self, event_ids, types):
|
||||
event_to_groups = yield self._get_state_group_for_events(
|
||||
event_ids,
|
||||
)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = yield self._get_state_for_groups(groups, types)
|
||||
|
||||
event_to_state = {
|
||||
event_id: group_to_state[group]
|
||||
for event_id, group in event_to_groups.items()
|
||||
}
|
||||
|
||||
defer.returnValue({event: event_to_state[event] for event in event_ids})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_event(self, event_id, types=None):
|
||||
"""
|
||||
|
@ -300,6 +316,23 @@ class StateStore(SQLBaseStore):
|
|||
state_map = yield self.get_state_for_events([event_id], types)
|
||||
defer.returnValue(state_map[event_id])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_ids_for_event(self, event_id, types=None):
|
||||
"""
|
||||
Get the state dict corresponding to a particular event
|
||||
|
||||
Args:
|
||||
event_id(str): event whose state should be returned
|
||||
types(list[(str, str)]|None): List of (type, state_key) tuples
|
||||
which are used to filter the state fetched. May be None, which
|
||||
matches any key
|
||||
|
||||
Returns:
|
||||
A deferred dict from (type, state_key) -> state_event
|
||||
"""
|
||||
state_map = yield self.get_state_ids_for_events([event_id], types)
|
||||
defer.returnValue(state_map[event_id])
|
||||
|
||||
@cached(num_args=2, max_entries=10000)
|
||||
def _get_state_group_for_event(self, room_id, event_id):
|
||||
return self._simple_select_one_onecol(
|
||||
|
|
Loading…
Reference in a new issue