0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-19 16:13:58 +01:00

py3-ize state.py

This commit is contained in:
Adrian Tschira 2018-05-24 20:58:21 +02:00
parent 7ea07c7305
commit 0e61705661

View file

@ -32,6 +32,8 @@ from frozendict import frozendict
import logging import logging
import hashlib import hashlib
from six import iteritems, itervalues
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -132,7 +134,7 @@ class StateHandler(object):
state_map = yield self.store.get_events(state.values(), get_prev_content=False) state_map = yield self.store.get_events(state.values(), get_prev_content=False)
state = { state = {
key: state_map[e_id] for key, e_id in state.iteritems() if e_id in state_map key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
} }
defer.returnValue(state) defer.returnValue(state)
@ -338,7 +340,7 @@ class StateHandler(object):
) )
if len(state_groups_ids) == 1: if len(state_groups_ids) == 1:
name, state_list = state_groups_ids.items().pop() name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name) prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@ -378,7 +380,7 @@ class StateHandler(object):
new_state = resolve_events_with_state_map(state_set_ids, state_map) new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = { new_state = {
key: state_map[ev_id] for key, ev_id in new_state.iteritems() key: state_map[ev_id] for key, ev_id in iteritems(new_state)
} }
return new_state return new_state
@ -458,15 +460,15 @@ class StateResolutionHandler(object):
# build a map from state key to the event_ids which set that state. # build a map from state key to the event_ids which set that state.
# dict[(str, str), set[str]) # dict[(str, str), set[str])
state = {} state = {}
for st in state_groups_ids.itervalues(): for st in itervalues(state_groups_ids):
for key, e_id in st.iteritems(): for key, e_id in iteritems(st):
state.setdefault(key, set()).add(e_id) state.setdefault(key, set()).add(e_id)
# build a map from state key to the event_ids which set that state, # build a map from state key to the event_ids which set that state,
# including only those where there are state keys in conflict. # including only those where there are state keys in conflict.
conflicted_state = { conflicted_state = {
k: list(v) k: list(v)
for k, v in state.iteritems() for k, v in iteritems(state)
if len(v) > 1 if len(v) > 1
} }
@ -474,13 +476,13 @@ class StateResolutionHandler(object):
logger.info("Resolving conflicted state for %r", room_id) logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory( new_state = yield resolve_events_with_factory(
state_groups_ids.values(), list(state_groups_ids.values()),
event_map=event_map, event_map=event_map,
state_map_factory=state_map_factory, state_map_factory=state_map_factory,
) )
else: else:
new_state = { new_state = {
key: e_ids.pop() for key, e_ids in state.iteritems() key: e_ids.pop() for key, e_ids in iteritems(state)
} }
with Measure(self.clock, "state.create_group_ids"): with Measure(self.clock, "state.create_group_ids"):
@ -489,8 +491,8 @@ class StateResolutionHandler(object):
# which will be used as a cache key for future resolutions, but # which will be used as a cache key for future resolutions, but
# not get persisted. # not get persisted.
state_group = None state_group = None
new_state_event_ids = frozenset(new_state.itervalues()) new_state_event_ids = frozenset(itervalues(new_state))
for sg, events in state_groups_ids.iteritems(): for sg, events in iteritems(state_groups_ids):
if new_state_event_ids == frozenset(e_id for e_id in events): if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg state_group = sg
break break
@ -501,11 +503,11 @@ class StateResolutionHandler(object):
prev_group = None prev_group = None
delta_ids = None delta_ids = None
for old_group, old_ids in state_groups_ids.iteritems(): for old_group, old_ids in iteritems(state_groups_ids):
if not set(new_state) - set(old_ids): if not set(new_state) - set(old_ids):
n_delta_ids = { n_delta_ids = {
k: v k: v
for k, v in new_state.iteritems() for k, v in iteritems(new_state)
if old_ids.get(k) != v if old_ids.get(k) != v
} }
if not delta_ids or len(n_delta_ids) < len(delta_ids): if not delta_ids or len(n_delta_ids) < len(delta_ids):
@ -527,7 +529,7 @@ class StateResolutionHandler(object):
def _ordered_events(events): def _ordered_events(events):
def key_func(e): def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
return sorted(events, key=key_func) return sorted(events, key=key_func)
@ -584,7 +586,7 @@ def _seperate(state_sets):
conflicted_state = {} conflicted_state = {}
for state_set in state_sets[1:]: for state_set in state_sets[1:]:
for key, value in state_set.iteritems(): for key, value in iteritems(state_set):
# Check if there is an unconflicted entry for the state key. # Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key) unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None: if unconflicted_value is None:
@ -640,7 +642,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
needed_events = set( needed_events = set(
event_id event_id
for event_ids in conflicted_state.itervalues() for event_ids in itervalues(conflicted_state)
for event_id in event_ids for event_id in event_ids
) )
if event_map is not None: if event_map is not None:
@ -662,7 +664,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
unconflicted_state, conflicted_state, state_map unconflicted_state, conflicted_state, state_map
) )
new_needed_events = set(auth_events.itervalues()) new_needed_events = set(itervalues(auth_events))
new_needed_events -= needed_events new_needed_events -= needed_events
if event_map is not None: if event_map is not None:
new_needed_events -= set(event_map.iterkeys()) new_needed_events -= set(event_map.iterkeys())
@ -679,7 +681,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {} auth_events = {}
for event_ids in conflicted_state.itervalues(): for event_ids in itervalues(conflicted_state):
for event_id in event_ids: for event_id in event_ids:
if event_id in state_map: if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id]) keys = event_auth.auth_types_for_event(state_map[event_id])
@ -694,7 +696,7 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids, def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
state_map): state_map):
conflicted_state = {} conflicted_state = {}
for key, event_ids in conflicted_state_ds.iteritems(): for key, event_ids in iteritems(conflicted_state_ds):
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1: if len(events) > 1:
conflicted_state[key] = events conflicted_state[key] = events
@ -703,7 +705,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
auth_events = { auth_events = {
key: state_map[ev_id] key: state_map[ev_id]
for key, ev_id in auth_event_ids.iteritems() for key, ev_id in iteritems(auth_event_ids)
if ev_id in state_map if ev_id in state_map
} }
@ -716,7 +718,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
raise raise
new_state = unconflicted_state_ids new_state = unconflicted_state_ids
for key, event in resolved_state.iteritems(): for key, event in iteritems(resolved_state):
new_state[key] = event.event_id new_state[key] = event.event_id
return new_state return new_state
@ -741,7 +743,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in conflicted_state.iteritems(): for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.JoinRules: if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events) logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events( resolved_state[key] = _resolve_auth_events(
@ -751,7 +753,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in conflicted_state.iteritems(): for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.Member: if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events) logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events( resolved_state[key] = _resolve_auth_events(
@ -761,7 +763,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in conflicted_state.iteritems(): for key, events in iteritems(conflicted_state):
if key not in resolved_state: if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events) logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events( resolved_state[key] = _resolve_normal_events(