diff --git a/synapse/api/auth.py b/synapse/api/auth.py index f26e58562..fcf0b0d25 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -66,7 +66,7 @@ class Auth(object): @defer.inlineCallbacks def check_from_context(self, event, context, do_sig_check=True): auth_events_ids = yield self.compute_auth_events( - event, context.current_state_ids, for_verification=True, + event, context.prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -852,7 +852,7 @@ class Auth(object): @defer.inlineCallbacks def add_auth_events(self, builder, context): - auth_ids = yield self.compute_auth_events(builder, context.current_state_ids) + auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids) auth_events_entries = yield self.store.add_event_hashes( auth_ids diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index c75afd02d..e895b1c45 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -15,8 +15,9 @@ class EventContext(object): - def __init__(self, current_state_ids=None): - self.current_state_ids = current_state_ids + def __init__(self): + self.current_state_ids = None + self.prev_state_ids = None self.state_group = None self.rejected = False self.push_actions = [] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index a7ea8fb98..8e61d74b1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -222,7 +222,7 @@ class FederationHandler(BaseHandler): # joined the room. Don't bother if the user is just # changing their profile info. newly_joined = True - prev_state_id = context.current_state_ids.get( + prev_state_id = context.prev_state_ids.get( (event.type, event.state_key) ) if prev_state_id: @@ -835,12 +835,12 @@ class FederationHandler(BaseHandler): self.replication_layer.send_pdu(new_pdu, destinations) - state_ids = context.current_state_ids.values() + state_ids = context.prev_state_ids.values() auth_chain = yield self.store.get_auth_chain(set( [event.event_id] + state_ids )) - state = yield self.store.get_events(context.current_state_ids.values()) + state = yield self.store.get_events(context.prev_state_ids.values()) defer.returnValue({ "state": state.values(), @@ -1333,7 +1333,7 @@ class FederationHandler(BaseHandler): if not auth_events: auth_events_ids = yield self.auth.compute_auth_events( - event, context.current_state_ids, for_verification=True, + event, context.prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -1432,6 +1432,11 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) event_auth_events = set(e_id for e_id, _ in event.auth_events) + if event.is_state(): + event_key = (event.type, event.state_key) + else: + event_key = None + if event_auth_events - current_state: have_events = yield self.store.have_events( event_auth_events - current_state @@ -1537,8 +1542,12 @@ class FederationHandler(BaseHandler): context.current_state_ids.update({ k: a.event_id for k, a in auth_events.items() + if k != event_key }) - context.state_group = None + context.prev_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + }) + context.state_group = self.store.get_next_state_group() if different_auth and not event.internal_metadata.is_outlier(): logger.info("Different auth after resolution: %s", different_auth) @@ -1560,7 +1569,7 @@ class FederationHandler(BaseHandler): if do_resolution: # 1. Get what we think is the auth chain. auth_ids = yield self.auth.compute_auth_events( - event, context.current_state_ids + event, context.prev_state_ids ) local_auth_chain = yield self.store.get_auth_chain(auth_ids) @@ -1618,8 +1627,12 @@ class FederationHandler(BaseHandler): context.current_state_ids.update({ k: a.event_id for k, a in auth_events.items() + if k != event_key }) - context.state_group = None + context.prev_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + }) + context.state_group = self.store.get_next_state_group() try: self.auth.check(event, auth_events=auth_events) @@ -1855,7 +1868,7 @@ class FederationHandler(BaseHandler): event.content["third_party_invite"]["signed"]["token"] ) original_invite = None - original_invite_id = context.current_state_ids.get(key) + original_invite_id = context.prev_state_ids.get(key) if original_invite_id: original_invite = yield self.store.get_event( original_invite_id, allow_none=True @@ -1893,7 +1906,7 @@ class FederationHandler(BaseHandler): signed = event.content["third_party_invite"]["signed"] token = signed["token"] - invite_event_id = context.current_state_ids.get( + invite_event_id = context.prev_state_ids.get( (EventTypes.ThirdPartyInvite, token,) ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e2f4387f6..3577db059 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -272,7 +272,7 @@ class MessageHandler(BaseHandler): If so, returns the version of the event in context. Otherwise, returns None. """ - prev_event_id = context.current_state_ids.get((event.type, event.state_key)) + prev_event_id = context.prev_state_ids.get((event.type, event.state_key)) prev_event = yield self.store.get_event(prev_event_id, allow_none=True) if not prev_event: return @@ -808,8 +808,8 @@ class MessageHandler(BaseHandler): event = builder.build() logger.debug( - "Created event %s with current state: %s", - event.event_id, context.current_state_ids, + "Created event %s with state: %s", + event.event_id, context.prev_state_ids, ) defer.returnValue( @@ -904,7 +904,7 @@ class MessageHandler(BaseHandler): if event.type == EventTypes.Redaction: auth_events_ids = yield self.auth.compute_auth_events( - event, context.current_state_ids, for_verification=True, + event, context.prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -924,7 +924,7 @@ class MessageHandler(BaseHandler): "You don't have permission to redact events" ) - if event.type == EventTypes.Create and context.current_state_ids: + if event.type == EventTypes.Create and context.prev_state_ids: raise AuthError( 403, "Changing the room create event is forbidden", diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index dd4b90ee2..3ba5335af 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -93,7 +93,7 @@ class RoomMemberHandler(BaseHandler): ratelimit=ratelimit, ) - prev_member_event_id = context.current_state_ids.get( + prev_member_event_id = context.prev_state_ids.get( (EventTypes.Member, target.to_string()), None ) @@ -341,7 +341,7 @@ class RoomMemberHandler(BaseHandler): if event.membership == Membership.JOIN: if requester.is_guest: - guest_can_join = yield self._can_guest_join(context.current_state_ids) + guest_can_join = yield self._can_guest_join(context.prev_state_ids) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. @@ -355,7 +355,7 @@ class RoomMemberHandler(BaseHandler): ratelimit=ratelimit, ) - prev_member_event_id = context.current_state_ids.get( + prev_member_event_id = context.prev_state_ids.get( (EventTypes.Member, event.state_key), None ) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 51cb21ee9..6ff9a06de 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -87,7 +87,7 @@ class BulkPushRuleEvaluator: ) room_members = yield self.store.get_joined_users_from_context( - event.room_id, context.state_group, context.current_state_ids + event, context ) evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) diff --git a/synapse/state.py b/synapse/state.py index 147416fd8..a0f807e3b 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -128,7 +128,7 @@ class StateHandler(object): def get_current_user_in_room(self, room_id): latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids) - joined_users = yield self.store.get_joined_users_from_context( + joined_users = yield self.store.get_joined_users_from_state( room_id, group, state_ids ) defer.returnValue(joined_users) @@ -154,27 +154,38 @@ class StateHandler(object): # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. if old_state: - context.current_state_ids = { + context.prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } + if event.is_state(): + context.current_state_events = dict(context.prev_state_ids) + key = (event.type, event.state_key) + context.current_state_events[key] = event.event_id + else: + context.current_state_events = context.prev_state_ids else: context.current_state_ids = {} + context.prev_state_ids = {} context.prev_state_events = [] context.state_group = self.store.get_next_state_group() defer.returnValue(context) if old_state: - context.current_state_ids = { + context.prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } context.state_group = self.store.get_next_state_group() if event.is_state(): key = (event.type, event.state_key) - if key in context.current_state_ids: - replaces = context.current_state_ids[key] + if key in context.prev_state_ids: + replaces = context.prev_state_ids[key] if replaces != event.event_id: # Paranoia check event.unsigned["replaces_state"] = replaces + context.current_state_ids = dict(context.prev_state_ids) + context.current_state_ids[key] = event.event_id + else: + context.current_state_ids = context.prev_state_ids context.prev_state_events = [] defer.returnValue(context) @@ -192,7 +203,7 @@ class StateHandler(object): group, curr_state = ret - context.current_state_ids = curr_state + context.prev_state_ids = curr_state if event.is_state() or group is None: context.state_group = self.store.get_next_state_group() else: @@ -200,9 +211,13 @@ class StateHandler(object): if event.is_state(): key = (event.type, event.state_key) - if key in context.current_state_ids: - replaces = context.current_state_ids[key] + if key in context.prev_state_ids: + replaces = context.prev_state_ids[key] event.unsigned["replaces_state"] = replaces + context.current_state_ids = dict(context.prev_state_ids) + context.current_state_ids[key] = event.event_id + else: + context.current_state_ids = context.prev_state_ids context.prev_state_events = [] defer.returnValue(context) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index cab166083..6ab10db32 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -354,7 +354,8 @@ class RoomMemberStore(SQLBaseStore): desc="who_forgot" ) - def get_joined_users_from_context(self, room_id, state_group, state_ids): + def get_joined_users_from_context(self, event, context): + state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -363,12 +364,24 @@ class RoomMemberStore(SQLBaseStore): state_group = object() return self._get_joined_users_from_context( - room_id, state_group, state_ids + event.room_id, state_group, context.current_state_ids, event=event, + ) + + def get_joined_users_from_state(self, room_id, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + room_id, state_group, state_ids, ) @cachedInlineCallbacks(num_args=2, cache_context=True) def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, - cache_context): + cache_context, event=None): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. @@ -393,7 +406,13 @@ class RoomMemberStore(SQLBaseStore): desc="_get_joined_users_from_context", ) - defer.returnValue(set(row["user_id"] for row in rows)) + users_in_room = set(row["user_id"] for row in rows) + if event is not None and event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + if event.event_id in member_event_ids: + users_in_room.add(event.state_key) + + defer.returnValue(users_in_room) def is_host_joined(self, room_id, host, state_group, state_ids): if not state_group: diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 56bfdc0b5..dce5a2f13 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -108,9 +108,6 @@ class StateStore(SQLBaseStore): state_event_ids = dict(context.current_state_ids) - if event.is_state(): - state_event_ids[(event.type, event.state_key)] = event.event_id - self._simple_insert_txn( txn, table="state_groups", diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 218cb2488..44e859b5d 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -312,7 +312,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): else: state_ids = None - context = EventContext(current_state_ids=state_ids) + context = EventContext() + context.current_state_ids = state_ids + context.prev_state_ids = state_ids context.push_actions = push_actions ordering = None diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index e70ac6f14..b69832cc1 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -60,8 +60,8 @@ class ReplicationResourceCase(unittest.TestCase): self.assertEquals(body, {}) @defer.inlineCallbacks - def test_events_and_state(self): - get = self.get(events="-1", state="-1", timeout="0") + def test_events(self): + get = self.get(events="-1", timeout="0") yield self.hs.get_handlers().room_creation_handler.create_room( synapse.types.create_requester(self.user), {} ) @@ -70,12 +70,6 @@ class ReplicationResourceCase(unittest.TestCase): self.assertEquals(body["events"]["field_names"], [ "position", "internal", "json", "state_group" ]) - self.assertEquals(body["state_groups"]["field_names"], [ - "position", "room_id", "event_id" - ]) - self.assertEquals(body["state_group_state"]["field_names"], [ - "position", "type", "state_key", "event_id" - ]) @defer.inlineCallbacks def test_presence(self): diff --git a/tests/test_state.py b/tests/test_state.py index de2d35145..6454f994e 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -86,17 +86,8 @@ class StateGroupStore(object): state_events = dict(context.current_state_ids) - if event.is_state(): - state_events[(event.type, event.state_key)] = event.event_id - - state_group = context.state_group - if not state_group: - state_group = self._next_group - self._next_group += 1 - - self._group_to_state[state_group] = state_events - - self._event_to_state_group[event.event_id] = state_group + self._group_to_state[context.state_group] = state_events + self._event_to_state_group[event.event_id] = context.state_group def get_events(self, event_ids, **kwargs): return { @@ -151,6 +142,7 @@ class StateTestCase(unittest.TestCase): "get_state_groups_ids", "add_event_hashes", "get_events", + "get_next_state_group", ] ) hs = Mock(spec_set=[ @@ -161,6 +153,8 @@ class StateTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) + self.store.get_next_state_group.side_effect = Mock + self.state = StateHandler(hs) self.event_id = 0 @@ -209,7 +203,7 @@ class StateTestCase(unittest.TestCase): store.store_state_groups(event, context) context_store[event.event_id] = context - self.assertEqual(2, len(context_store["D"].current_state_ids)) + self.assertEqual(2, len(context_store["D"].prev_state_ids)) @defer.inlineCallbacks def test_branch_basic_conflict(self): @@ -265,7 +259,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"START", "A", "C"}, - {e_id for e_id in context_store["D"].current_state_ids.values()} + {e_id for e_id in context_store["D"].prev_state_ids.values()} ) @defer.inlineCallbacks @@ -331,7 +325,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"START", "A", "B", "C"}, - {e for e in context_store["E"].current_state_ids.values()} + {e for e in context_store["E"].prev_state_ids.values()} ) @defer.inlineCallbacks @@ -414,7 +408,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"A1", "A2", "A3", "A5", "B"}, - {e for e in context_store["D"].current_state_ids.values()} + {e for e in context_store["D"].prev_state_ids.values()} ) def _add_depths(self, nodes, edges): @@ -447,7 +441,7 @@ class StateTestCase(unittest.TestCase): set(e.event_id for e in old_state), set(context.current_state_ids.values()) ) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_annotate_with_old_state(self): @@ -464,11 +458,9 @@ class StateTestCase(unittest.TestCase): ) self.assertEqual( - set(e.event_id for e in old_state), set(context.current_state_ids.values()) + set(e.event_id for e in old_state), set(context.prev_state_ids.values()) ) - self.assertIsNone(context.state_group) - @defer.inlineCallbacks def test_trivial_annotate_message(self): event = create_event(type="test_message", name="event") @@ -514,10 +506,10 @@ class StateTestCase(unittest.TestCase): self.assertEqual( set([e.event_id for e in old_state]), - set(context.current_state_ids.values()) + set(context.prev_state_ids.values()) ) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_resolve_message_conflict(self): @@ -550,7 +542,7 @@ class StateTestCase(unittest.TestCase): self.assertEqual(len(context.current_state_ids), 6) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_resolve_state_conflict(self): @@ -583,7 +575,7 @@ class StateTestCase(unittest.TestCase): self.assertEqual(len(context.current_state_ids), 6) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_standard_depth_conflict(self):