diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 3471afd7e..7105ee21d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -358,7 +358,7 @@ class Auth(object): def add_auth_events(self, builder, context): yield run_on_reactor() - auth_ids = self.compute_auth_events(builder, context) + auth_ids = self.compute_auth_events(builder, context.current_state) auth_events_entries = yield self.store.add_event_hashes( auth_ids @@ -372,26 +372,26 @@ class Auth(object): if v.event_id in auth_ids } - def compute_auth_events(self, event, context): + def compute_auth_events(self, event, current_state): if event.type == EventTypes.Create: return [] auth_ids = [] key = (EventTypes.PowerLevels, "", ) - power_level_event = context.current_state.get(key) + power_level_event = current_state.get(key) if power_level_event: auth_ids.append(power_level_event.event_id) key = (EventTypes.JoinRules, "", ) - join_rule_event = context.current_state.get(key) + join_rule_event = current_state.get(key) key = (EventTypes.Member, event.user_id, ) - member_event = context.current_state.get(key) + member_event = current_state.get(key) key = (EventTypes.Create, "", ) - create_event = context.current_state.get(key) + create_event = current_state.get(key) if create_event: auth_ids.append(create_event.event_id) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index a968a8736..04a468948 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -842,7 +842,9 @@ class FederationHandler(BaseHandler): logger.debug("Different auth: %s", different_auth) # 1. Get what we think is the auth chain. - auth_ids = self.auth.compute_auth_events(event, context) + auth_ids = self.auth.compute_auth_events( + event, context.current_state + ) local_auth_chain = yield self.store.get_auth_chain(auth_ids) try: diff --git a/synapse/state.py b/synapse/state.py index 6a6fb8aea..695a5e7ac 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -103,7 +103,9 @@ class StateHandler(object): context.state_group = None if hasattr(event, "auth_events") and event.auth_events: - auth_ids = zip(*event.auth_events)[0] + auth_ids = self.hs.get_auth().compute_auth_events( + event, context.current_state + ) context.auth_events = { k: v for k, v in context.current_state.items() @@ -149,7 +151,9 @@ class StateHandler(object): event.unsigned["replaces_state"] = replaces.event_id if hasattr(event, "auth_events") and event.auth_events: - auth_ids = zip(*event.auth_events)[0] + auth_ids = self.hs.get_auth().compute_auth_events( + event, context.current_state + ) context.auth_events = { k: v for k, v in context.current_state.items()