mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-19 16:13:58 +01:00
py3-ize state.py
This commit is contained in:
parent
7ea07c7305
commit
0e61705661
1 changed files with 25 additions and 23 deletions
|
@ -32,6 +32,8 @@ from frozendict import frozendict
|
||||||
import logging
|
import logging
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
|
from six import iteritems, itervalues
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -132,7 +134,7 @@ class StateHandler(object):
|
||||||
|
|
||||||
state_map = yield self.store.get_events(state.values(), get_prev_content=False)
|
state_map = yield self.store.get_events(state.values(), get_prev_content=False)
|
||||||
state = {
|
state = {
|
||||||
key: state_map[e_id] for key, e_id in state.iteritems() if e_id in state_map
|
key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
|
||||||
}
|
}
|
||||||
|
|
||||||
defer.returnValue(state)
|
defer.returnValue(state)
|
||||||
|
@ -338,7 +340,7 @@ class StateHandler(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(state_groups_ids) == 1:
|
if len(state_groups_ids) == 1:
|
||||||
name, state_list = state_groups_ids.items().pop()
|
name, state_list = list(state_groups_ids.items()).pop()
|
||||||
|
|
||||||
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
|
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
|
||||||
|
|
||||||
|
@ -378,7 +380,7 @@ class StateHandler(object):
|
||||||
new_state = resolve_events_with_state_map(state_set_ids, state_map)
|
new_state = resolve_events_with_state_map(state_set_ids, state_map)
|
||||||
|
|
||||||
new_state = {
|
new_state = {
|
||||||
key: state_map[ev_id] for key, ev_id in new_state.iteritems()
|
key: state_map[ev_id] for key, ev_id in iteritems(new_state)
|
||||||
}
|
}
|
||||||
|
|
||||||
return new_state
|
return new_state
|
||||||
|
@ -458,15 +460,15 @@ class StateResolutionHandler(object):
|
||||||
# build a map from state key to the event_ids which set that state.
|
# build a map from state key to the event_ids which set that state.
|
||||||
# dict[(str, str), set[str])
|
# dict[(str, str), set[str])
|
||||||
state = {}
|
state = {}
|
||||||
for st in state_groups_ids.itervalues():
|
for st in itervalues(state_groups_ids):
|
||||||
for key, e_id in st.iteritems():
|
for key, e_id in iteritems(st):
|
||||||
state.setdefault(key, set()).add(e_id)
|
state.setdefault(key, set()).add(e_id)
|
||||||
|
|
||||||
# build a map from state key to the event_ids which set that state,
|
# build a map from state key to the event_ids which set that state,
|
||||||
# including only those where there are state keys in conflict.
|
# including only those where there are state keys in conflict.
|
||||||
conflicted_state = {
|
conflicted_state = {
|
||||||
k: list(v)
|
k: list(v)
|
||||||
for k, v in state.iteritems()
|
for k, v in iteritems(state)
|
||||||
if len(v) > 1
|
if len(v) > 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -474,13 +476,13 @@ class StateResolutionHandler(object):
|
||||||
logger.info("Resolving conflicted state for %r", room_id)
|
logger.info("Resolving conflicted state for %r", room_id)
|
||||||
with Measure(self.clock, "state._resolve_events"):
|
with Measure(self.clock, "state._resolve_events"):
|
||||||
new_state = yield resolve_events_with_factory(
|
new_state = yield resolve_events_with_factory(
|
||||||
state_groups_ids.values(),
|
list(state_groups_ids.values()),
|
||||||
event_map=event_map,
|
event_map=event_map,
|
||||||
state_map_factory=state_map_factory,
|
state_map_factory=state_map_factory,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
new_state = {
|
new_state = {
|
||||||
key: e_ids.pop() for key, e_ids in state.iteritems()
|
key: e_ids.pop() for key, e_ids in iteritems(state)
|
||||||
}
|
}
|
||||||
|
|
||||||
with Measure(self.clock, "state.create_group_ids"):
|
with Measure(self.clock, "state.create_group_ids"):
|
||||||
|
@ -489,8 +491,8 @@ class StateResolutionHandler(object):
|
||||||
# which will be used as a cache key for future resolutions, but
|
# which will be used as a cache key for future resolutions, but
|
||||||
# not get persisted.
|
# not get persisted.
|
||||||
state_group = None
|
state_group = None
|
||||||
new_state_event_ids = frozenset(new_state.itervalues())
|
new_state_event_ids = frozenset(itervalues(new_state))
|
||||||
for sg, events in state_groups_ids.iteritems():
|
for sg, events in iteritems(state_groups_ids):
|
||||||
if new_state_event_ids == frozenset(e_id for e_id in events):
|
if new_state_event_ids == frozenset(e_id for e_id in events):
|
||||||
state_group = sg
|
state_group = sg
|
||||||
break
|
break
|
||||||
|
@ -501,11 +503,11 @@ class StateResolutionHandler(object):
|
||||||
|
|
||||||
prev_group = None
|
prev_group = None
|
||||||
delta_ids = None
|
delta_ids = None
|
||||||
for old_group, old_ids in state_groups_ids.iteritems():
|
for old_group, old_ids in iteritems(state_groups_ids):
|
||||||
if not set(new_state) - set(old_ids):
|
if not set(new_state) - set(old_ids):
|
||||||
n_delta_ids = {
|
n_delta_ids = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in new_state.iteritems()
|
for k, v in iteritems(new_state)
|
||||||
if old_ids.get(k) != v
|
if old_ids.get(k) != v
|
||||||
}
|
}
|
||||||
if not delta_ids or len(n_delta_ids) < len(delta_ids):
|
if not delta_ids or len(n_delta_ids) < len(delta_ids):
|
||||||
|
@ -527,7 +529,7 @@ class StateResolutionHandler(object):
|
||||||
|
|
||||||
def _ordered_events(events):
|
def _ordered_events(events):
|
||||||
def key_func(e):
|
def key_func(e):
|
||||||
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
|
return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
|
||||||
|
|
||||||
return sorted(events, key=key_func)
|
return sorted(events, key=key_func)
|
||||||
|
|
||||||
|
@ -584,7 +586,7 @@ def _seperate(state_sets):
|
||||||
conflicted_state = {}
|
conflicted_state = {}
|
||||||
|
|
||||||
for state_set in state_sets[1:]:
|
for state_set in state_sets[1:]:
|
||||||
for key, value in state_set.iteritems():
|
for key, value in iteritems(state_set):
|
||||||
# Check if there is an unconflicted entry for the state key.
|
# Check if there is an unconflicted entry for the state key.
|
||||||
unconflicted_value = unconflicted_state.get(key)
|
unconflicted_value = unconflicted_state.get(key)
|
||||||
if unconflicted_value is None:
|
if unconflicted_value is None:
|
||||||
|
@ -640,7 +642,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
|
||||||
|
|
||||||
needed_events = set(
|
needed_events = set(
|
||||||
event_id
|
event_id
|
||||||
for event_ids in conflicted_state.itervalues()
|
for event_ids in itervalues(conflicted_state)
|
||||||
for event_id in event_ids
|
for event_id in event_ids
|
||||||
)
|
)
|
||||||
if event_map is not None:
|
if event_map is not None:
|
||||||
|
@ -662,7 +664,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
|
||||||
unconflicted_state, conflicted_state, state_map
|
unconflicted_state, conflicted_state, state_map
|
||||||
)
|
)
|
||||||
|
|
||||||
new_needed_events = set(auth_events.itervalues())
|
new_needed_events = set(itervalues(auth_events))
|
||||||
new_needed_events -= needed_events
|
new_needed_events -= needed_events
|
||||||
if event_map is not None:
|
if event_map is not None:
|
||||||
new_needed_events -= set(event_map.iterkeys())
|
new_needed_events -= set(event_map.iterkeys())
|
||||||
|
@ -679,7 +681,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
|
||||||
|
|
||||||
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
|
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
|
||||||
auth_events = {}
|
auth_events = {}
|
||||||
for event_ids in conflicted_state.itervalues():
|
for event_ids in itervalues(conflicted_state):
|
||||||
for event_id in event_ids:
|
for event_id in event_ids:
|
||||||
if event_id in state_map:
|
if event_id in state_map:
|
||||||
keys = event_auth.auth_types_for_event(state_map[event_id])
|
keys = event_auth.auth_types_for_event(state_map[event_id])
|
||||||
|
@ -694,7 +696,7 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
|
||||||
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
|
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
|
||||||
state_map):
|
state_map):
|
||||||
conflicted_state = {}
|
conflicted_state = {}
|
||||||
for key, event_ids in conflicted_state_ds.iteritems():
|
for key, event_ids in iteritems(conflicted_state_ds):
|
||||||
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
|
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
|
||||||
if len(events) > 1:
|
if len(events) > 1:
|
||||||
conflicted_state[key] = events
|
conflicted_state[key] = events
|
||||||
|
@ -703,7 +705,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
|
||||||
|
|
||||||
auth_events = {
|
auth_events = {
|
||||||
key: state_map[ev_id]
|
key: state_map[ev_id]
|
||||||
for key, ev_id in auth_event_ids.iteritems()
|
for key, ev_id in iteritems(auth_event_ids)
|
||||||
if ev_id in state_map
|
if ev_id in state_map
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -716,7 +718,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
|
||||||
raise
|
raise
|
||||||
|
|
||||||
new_state = unconflicted_state_ids
|
new_state = unconflicted_state_ids
|
||||||
for key, event in resolved_state.iteritems():
|
for key, event in iteritems(resolved_state):
|
||||||
new_state[key] = event.event_id
|
new_state[key] = event.event_id
|
||||||
|
|
||||||
return new_state
|
return new_state
|
||||||
|
@ -741,7 +743,7 @@ def _resolve_state_events(conflicted_state, auth_events):
|
||||||
|
|
||||||
auth_events.update(resolved_state)
|
auth_events.update(resolved_state)
|
||||||
|
|
||||||
for key, events in conflicted_state.iteritems():
|
for key, events in iteritems(conflicted_state):
|
||||||
if key[0] == EventTypes.JoinRules:
|
if key[0] == EventTypes.JoinRules:
|
||||||
logger.debug("Resolving conflicted join rules %r", events)
|
logger.debug("Resolving conflicted join rules %r", events)
|
||||||
resolved_state[key] = _resolve_auth_events(
|
resolved_state[key] = _resolve_auth_events(
|
||||||
|
@ -751,7 +753,7 @@ def _resolve_state_events(conflicted_state, auth_events):
|
||||||
|
|
||||||
auth_events.update(resolved_state)
|
auth_events.update(resolved_state)
|
||||||
|
|
||||||
for key, events in conflicted_state.iteritems():
|
for key, events in iteritems(conflicted_state):
|
||||||
if key[0] == EventTypes.Member:
|
if key[0] == EventTypes.Member:
|
||||||
logger.debug("Resolving conflicted member lists %r", events)
|
logger.debug("Resolving conflicted member lists %r", events)
|
||||||
resolved_state[key] = _resolve_auth_events(
|
resolved_state[key] = _resolve_auth_events(
|
||||||
|
@ -761,7 +763,7 @@ def _resolve_state_events(conflicted_state, auth_events):
|
||||||
|
|
||||||
auth_events.update(resolved_state)
|
auth_events.update(resolved_state)
|
||||||
|
|
||||||
for key, events in conflicted_state.iteritems():
|
for key, events in iteritems(conflicted_state):
|
||||||
if key not in resolved_state:
|
if key not in resolved_state:
|
||||||
logger.debug("Resolving conflicted state %r:%r", key, events)
|
logger.debug("Resolving conflicted state %r:%r", key, events)
|
||||||
resolved_state[key] = _resolve_normal_events(
|
resolved_state[key] = _resolve_normal_events(
|
||||||
|
|
Loading…
Reference in a new issue