diff --git a/synapse/state.py b/synapse/state.py index c75499c3e..90b14e758 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -337,7 +337,7 @@ class StateHandler(object): for st in state_groups_ids.values() ] with Measure(self.clock, "state._resolve_events"): - new_state, _ = Resolver.resolve_events( + new_state, _ = resolve_events( state_sets, event_type, state_key ) new_state = { @@ -392,11 +392,11 @@ class StateHandler(object): ) with Measure(self.clock, "state._resolve_events"): if event.is_state(): - return Resolver.resolve_events( + return resolve_events( state_sets, event.type, event.state_key ) else: - return Resolver.resolve_events(state_sets) + return resolve_events(state_sets) def _ordered_events(events): @@ -406,138 +406,136 @@ def _ordered_events(events): return sorted(events, key=key_func) -class Resolver(object): - @staticmethod - def resolve_events(state_sets, event_type=None, state_key=""): - """ - Returns - (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple - (new_state, prev_states). new_state is a map from (type, state_key) - to event. prev_states is a list of event_ids. - """ - state = {} - for st in state_sets: - for e in st: - state.setdefault( - (e.type, e.state_key), - {} - )[e.event_id] = e +def resolve_events(state_sets, event_type=None, state_key=""): + """ + Returns + (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple + (new_state, prev_states). new_state is a map from (type, state_key) + to event. prev_states is a list of event_ids. + """ + state = {} + for st in state_sets: + for e in st: + state.setdefault( + (e.type, e.state_key), + {} + )[e.event_id] = e - unconflicted_state = { - k: v.values()[0] for k, v in state.items() - if len(v.values()) == 1 - } + unconflicted_state = { + k: v.values()[0] for k, v in state.items() + if len(v.values()) == 1 + } - conflicted_state = { - k: v.values() - for k, v in state.items() - if len(v.values()) > 1 - } + conflicted_state = { + k: v.values() + for k, v in state.items() + if len(v.values()) > 1 + } - if event_type: - prev_states_events = conflicted_state.get( - (event_type, state_key), [] + if event_type: + prev_states_events = conflicted_state.get( + (event_type, state_key), [] + ) + prev_states = [s.event_id for s in prev_states_events] + else: + prev_states = [] + + auth_events = { + k: e for k, e in unconflicted_state.items() + if k[0] in AuthEventTypes + } + + try: + resolved_state = _resolve_state_events( + conflicted_state, auth_events + ) + except: + logger.exception("Failed to resolve state") + raise + + new_state = unconflicted_state + new_state.update(resolved_state) + + return new_state, prev_states + + +def _resolve_state_events(conflicted_state, auth_events): + """ This is where we actually decide which of the conflicted state to + use. + + We resolve conflicts in the following order: + 1. power levels + 2. join rules + 3. memberships + 4. other events. + """ + resolved_state = {} + power_key = (EventTypes.PowerLevels, "") + if power_key in conflicted_state: + events = conflicted_state[power_key] + logger.debug("Resolving conflicted power levels %r", events) + resolved_state[power_key] = _resolve_auth_events( + events, auth_events) + + auth_events.update(resolved_state) + + for key, events in conflicted_state.items(): + if key[0] == EventTypes.JoinRules: + logger.debug("Resolving conflicted join rules %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events ) - prev_states = [s.event_id for s in prev_states_events] - else: - prev_states = [] - auth_events = { - k: e for k, e in unconflicted_state.items() - if k[0] in AuthEventTypes - } + auth_events.update(resolved_state) + for key, events in conflicted_state.items(): + if key[0] == EventTypes.Member: + logger.debug("Resolving conflicted member lists %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events + ) + + auth_events.update(resolved_state) + + for key, events in conflicted_state.items(): + if key not in resolved_state: + logger.debug("Resolving conflicted state %r:%r", key, events) + resolved_state[key] = _resolve_normal_events( + events, auth_events + ) + + return resolved_state + + +def _resolve_auth_events(events, auth_events): + reverse = [i for i in reversed(_ordered_events(events))] + + auth_events = dict(auth_events) + + prev_event = reverse[0] + for event in reverse[1:]: + auth_events[(prev_event.type, prev_event.state_key)] = prev_event try: - resolved_state = Resolver._resolve_state_events( - conflicted_state, auth_events - ) - except: - logger.exception("Failed to resolve state") - raise + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False) + prev_event = event + except AuthError: + return prev_event - new_state = unconflicted_state - new_state.update(resolved_state) + return event - return new_state, prev_states - @staticmethod - def _resolve_state_events(conflicted_state, auth_events): - """ This is where we actually decide which of the conflicted state to - use. +def _resolve_normal_events(events, auth_events): + for event in _ordered_events(events): + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False) + return event + except AuthError: + pass - We resolve conflicts in the following order: - 1. power levels - 2. join rules - 3. memberships - 4. other events. - """ - resolved_state = {} - power_key = (EventTypes.PowerLevels, "") - if power_key in conflicted_state: - events = conflicted_state[power_key] - logger.debug("Resolving conflicted power levels %r", events) - resolved_state[power_key] = Resolver._resolve_auth_events( - events, auth_events) - - auth_events.update(resolved_state) - - for key, events in conflicted_state.items(): - if key[0] == EventTypes.JoinRules: - logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = Resolver._resolve_auth_events( - events, - auth_events - ) - - auth_events.update(resolved_state) - - for key, events in conflicted_state.items(): - if key[0] == EventTypes.Member: - logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = Resolver._resolve_auth_events( - events, - auth_events - ) - - auth_events.update(resolved_state) - - for key, events in conflicted_state.items(): - if key not in resolved_state: - logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = Resolver._resolve_normal_events( - events, auth_events - ) - - return resolved_state - - @staticmethod - def _resolve_auth_events(events, auth_events): - reverse = [i for i in reversed(_ordered_events(events))] - - auth_events = dict(auth_events) - - prev_event = reverse[0] - for event in reverse[1:]: - auth_events[(prev_event.type, prev_event.state_key)] = prev_event - try: - # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False) - prev_event = event - except AuthError: - return prev_event - - return event - - @staticmethod - def _resolve_normal_events(events, auth_events): - for event in _ordered_events(events): - try: - # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False) - return event - except AuthError: - pass - - # Use the last event (the one with the least depth) if they all fail - # the auth check. - return event + # Use the last event (the one with the least depth) if they all fail + # the auth check. + return event