diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 31852b29a..91ec0995f 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -34,7 +34,7 @@ class Auth(object): self.store = hs.get_datastore() @defer.inlineCallbacks - def check(self, event, raises=False): + def check(self, event, snapshot, raises=False): """ Checks if this event is correctly authed. Returns: @@ -46,7 +46,11 @@ class Auth(object): try: if event.type in [RoomTopicEvent.TYPE, MessageEvent.TYPE, FeedbackEvent.TYPE]: - yield self.check_joined_room(event.room_id, event.user_id) + self._check_joined_room( + member=snapshot.membership_state, + user_id=snapshot.user_id, + room_id=snapshot.room_id, + ) defer.returnValue(True) elif event.type == RoomMemberEvent.TYPE: allowed = yield self.is_membership_change_allowed(event) @@ -67,14 +71,16 @@ class Auth(object): room_id=room_id, user_id=user_id ) - if not member or member.membership != Membership.JOIN: - raise AuthError(403, "User %s not in room %s" % - (user_id, room_id)) + self._check_joined_room(member, user_id, room_id) defer.returnValue(member) except AttributeError: pass defer.returnValue(None) + def _check_joined_room(self, member, user_id, room_id): + if not member or member.membership != Membership.JOIN: + raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) + @defer.inlineCallbacks def is_membership_change_allowed(self, event): # does this room even exist diff --git a/synapse/federation/handler.py b/synapse/federation/handler.py index 984c1558e..ce98f4f94 100644 --- a/synapse/federation/handler.py +++ b/synapse/federation/handler.py @@ -51,19 +51,20 @@ class FederationEventHandler(object): @log_function @defer.inlineCallbacks - def handle_new_event(self, event): + def handle_new_event(self, event, snapshot): """ Takes in an event from the client to server side, that has already been authed and handled by the state module, and sends it to any remote home servers that may be interested. Args: event + snapshot (.storage.Snapshot): THe snapshot the event happened after Returns: Deferred: Resolved when it has successfully been queued for processing. """ - yield self.fill_out_prev_events(event) + yield self.fill_out_prev_events(event, snapshot) pdu = self.pdu_codec.pdu_from_event(event) @@ -137,13 +138,11 @@ class FederationEventHandler(object): yield self.event_handler.on_receive(new_state_event) @defer.inlineCallbacks - def fill_out_prev_events(self, event): + def fill_out_prev_events(self, event, snapshot): if hasattr(event, "prev_events"): return - results = yield self.store.get_latest_pdus_in_context( - event.room_id - ) + results = snapshot.prev_pdus es = [ "%s@%s" % (p_id, origin) for p_id, origin, _ in results diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6229ee9bf..074363e80 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -85,20 +85,21 @@ class MessageHandler(BaseHandler): if stamp_event: event.content["hsob_ts"] = int(self.clock.time_msec()) - with (yield self.room_lock.lock(event.room_id)): - if not suppress_auth: - yield self.auth.check(event, raises=True) + snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) - # store message in db - store_id = yield self.store.persist_event(event) + if not suppress_auth: + yield self.auth.check(event, snapshot, raises=True) - event.destinations = yield self.store.get_joined_hosts_for_room( - event.room_id - ) + # store message in db + store_id = yield self.store.persist_event(event) - self.notifier.on_new_room_event(event, store_id) + event.destinations = yield self.store.get_joined_hosts_for_room( + event.room_id + ) - yield self.hs.get_federation().handle_new_event(event) + self.notifier.on_new_room_event(event, store_id) + + yield self.hs.get_federation().handle_new_event(event, snapshot) @defer.inlineCallbacks def get_messages(self, user_id=None, room_id=None, pagin_config=None, @@ -135,23 +136,24 @@ class MessageHandler(BaseHandler): SynapseError if something went wrong. """ - with (yield self.room_lock.lock(event.room_id)): - yield self.auth.check(event, raises=True) + snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) - if stamp_event: - event.content["hsob_ts"] = int(self.clock.time_msec()) + yield self.auth.check(event, snapshot, raises=True) - yield self.state_handler.handle_new_event(event) + if stamp_event: + event.content["hsob_ts"] = int(self.clock.time_msec()) - # store in db - store_id = yield self.store.persist_event(event) + yield self.state_handler.handle_new_event(event) - event.destinations = yield self.store.get_joined_hosts_for_room( - event.room_id - ) - self.notifier.on_new_room_event(event, store_id) + # store in db + store_id = yield self.store.persist_event(event) - yield self.hs.get_federation().handle_new_event(event) + event.destinations = yield self.store.get_joined_hosts_for_room( + event.room_id + ) + self.notifier.on_new_room_event(event, store_id) + + yield self.hs.get_federation().handle_new_event(event, snapshot) @defer.inlineCallbacks def get_room_data(self, user_id=None, room_id=None, @@ -220,16 +222,18 @@ class MessageHandler(BaseHandler): if stamp_event: event.content["hsob_ts"] = int(self.clock.time_msec()) - with (yield self.room_lock.lock(event.room_id)): - yield self.auth.check(event, raises=True) + snapshot = yield self.store.snapshot_room(event.room_id, user_id) - # store message in db - store_id = yield self.store.persist_event(event) - event.destinations = yield self.store.get_joined_hosts_for_room( - event.room_id - ) - yield self.hs.get_federation().handle_new_event(event) + yield self.auth.check(event, snapshot, raises=True) + + # store message in db + store_id = yield self.store.persist_event(event) + + event.destinations = yield self.store.get_joined_hosts_for_room( + event.room_id + ) + yield self.hs.get_federation().handle_new_event(event, snapshot) self.notifier.on_new_room_event(event, store_id) @@ -503,6 +507,11 @@ class RoomMemberHandler(BaseHandler): SynapseError if there was a problem changing the membership. """ + snapshot = yield self.store.snapshot_room( + event.room_id, event.user_id, + RoomMemberEvent.TYPE, event.target_user_id + ) + ## TODO(markjh): get prev state from snapshot. prev_state = yield self.store.get_room_member( event.target_user_id, event.room_id ) @@ -523,24 +532,22 @@ class RoomMemberHandler(BaseHandler): # if this HS is not currently in the room, i.e. we have to do the # invite/join dance. if event.membership == Membership.JOIN: - yield self._do_join(event, do_auth=do_auth) + yield self._do_join(event, snapshot, do_auth=do_auth) else: # This is not a JOIN, so we can handle it normally. if do_auth: - yield self.auth.check(event, raises=True) + yield self.auth.check(event, snapshot, raises=True) - prev_state = yield self.store.get_room_member( - event.target_user_id, event.room_id - ) if prev_state and prev_state.membership == event.membership: # double same action, treat this event as a NOOP. defer.returnValue({}) return - yield self.state_handler.handle_new_event(event) + yield self.state_handler.handle_new_event(event, snapshot) yield self._do_local_membership_update( event, membership=event.content["membership"], + snapshot=snapshot, ) defer.returnValue({"room_id": room_id}) @@ -570,12 +577,16 @@ class RoomMemberHandler(BaseHandler): content=content, ) - yield self._do_join(new_event, room_host=host, do_auth=True) + snapshot = yield store.snapshot_room( + room_id, joinee, RoomMemberEvent.TYPE, event.target_user_id + ) + + yield self._do_join(new_event, snapshot, room_host=host, do_auth=True) defer.returnValue({"room_id": room_id}) @defer.inlineCallbacks - def _do_join(self, event, room_host=None, do_auth=True): + def _do_join(self, event, snapshot, room_host=None, do_auth=True): joinee = self.hs.parse_userid(event.target_user_id) # room_id = RoomID.from_string(event.room_id, self.hs) room_id = event.room_id @@ -597,6 +608,7 @@ class RoomMemberHandler(BaseHandler): elif room_host: should_do_dance = True else: + # TODO(markjh): get prev_state from snapshot prev_state = yield self.store.get_room_member( joinee.to_string(), room_id ) @@ -624,12 +636,13 @@ class RoomMemberHandler(BaseHandler): logger.debug("Doing normal join") if do_auth: - yield self.auth.check(event, raises=True) + yield self.auth.check(event, snapshot, raises=True) - yield self.state_handler.handle_new_event(event) + yield self.state_handler.handle_new_event(event, snapshot) yield self._do_local_membership_update( event, membership=event.content["membership"], + snapshot=snapshot, ) user = self.hs.parse_userid(event.user_id) @@ -674,7 +687,7 @@ class RoomMemberHandler(BaseHandler): defer.returnValue([r.room_id for r in rooms]) @defer.inlineCallbacks - def _do_local_membership_update(self, event, membership): + def _do_local_membership_update(self, event, membership, snapshot): # store membership store_id = yield self.store.persist_event(event) @@ -700,7 +713,7 @@ class RoomMemberHandler(BaseHandler): event.destinations = list(set(destinations)) - yield self.hs.get_federation().handle_new_event(event) + yield self.hs.get_federation().handle_new_event(event, snapshot) self.notifier.on_new_room_event(event, store_id) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 773290692..d23df1509 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -187,6 +187,70 @@ class DataStore(RoomMemberStore, RoomStore, defer.returnValue(self.min_token) + def snapshot_room(self, room_id, user_id, state_type=None, state_key=None): + """Snapshot the room for an update by a user + Args: + room_id (synapse.types.RoomId): The room to snapshot. + user_id (synapse.types.UserId): The user to snapshot the room for. + state_type (str): Optional state type to snapshot. + state_key (str): Optional state key to snapshot. + Returns: + synapse.storage.Snapshot: A snapshot of the state of the room. + """ + def _snapshot(txn): + membership_state = self._get_room_member(txn, user_id) + prev_pdus = self._get_latest_pdus_in_context( + txn, room_id + ) + if state_type is not None and state_key is not None: + prev_state_pdu = self._get_current_state_pdu( + txn, room_id, state_type, state_key + ) + else: + prev_state_pdu = None + + return Snapshot( + store=self, + room_id=room_id, + user_id=user_id, + prev_pdus=prev_pdus, + membership_state=membership_state, + state_type=state_type, + state_key=state_key, + prev_state_pdu=prev_state_pdu, + ) + + return self._db_pool.runInteraction(_snapshot) + + +class Snapshot(object): + """Snapshot of the state of a room + Args: + store (DataStore): The datastore. + room_id (RoomId): The room of the snapshot. + user_id (UserId): The user this snapshot is for. + prev_pdus (list): The list of PDU ids this snapshot is after. + membership_state (RoomMemberEvent): The current state of the user in + the room. + state_type (str, optional): State type captured by the snapshot + state_key (str, optional): State key captured by the snapshot + prev_state_pdu (PduEntry, optional): pdu id of + the previous value of the state type and key in the room. + """ + + def __init__(self, store, room_id, user_id, prev_pdus, + membership_state, state_type=None, state_key=None, + prev_state_pdu=None): + self.store = store + self.room_id = room_id + self.user_id = user_id + self.prev_pdus = prev_pdus + self.membership_state + self.state_type = state_type + self.state_key = state_key + self.prev_state_pdu = prev_state_pdu + + def schema_path(schema): """ Get a filesystem path for the named database schema diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py index bf71d3be3..9ab409643 100644 --- a/tests/handlers/test_room.py +++ b/tests/handlers/test_room.py @@ -45,6 +45,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): "get_room_member", "get_room", "store_room", + "snapshot_room", ]), resource_for_federation=NonCallableMock(), http_client=NonCallableMock(spec_set=[]), @@ -75,6 +76,10 @@ class RoomMemberHandlerTestCase(unittest.TestCase): self.handlers.profile_handler = ProfileHandler(self.hs) self.room_member_handler = self.handlers.room_member_handler + self.snapshot = Mock() + self.datastore.snapshot_room.return_value = self.snapshot + + @defer.inlineCallbacks def test_invite(self): room_id = "!foo:red" @@ -104,8 +109,12 @@ class RoomMemberHandlerTestCase(unittest.TestCase): # Actual invocation yield self.room_member_handler.change_membership(event) - self.state_handler.handle_new_event.assert_called_once_with(event) - self.federation.handle_new_event.assert_called_once_with(event) + self.state_handler.handle_new_event.assert_called_once_with( + event, self.snapshot, + ) + self.federation.handle_new_event.assert_called_once_with( + event, self.snapshot, + ) self.assertEquals( set(["blue", "red", "green"]), @@ -116,7 +125,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase): event ) self.notifier.on_new_room_event.assert_called_once_with( - event, store_id) + event, store_id + ) self.assertFalse(self.datastore.get_room.called) self.assertFalse(self.datastore.store_room.called) @@ -148,6 +158,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): self.datastore.get_joined_hosts_for_room.side_effect = get_joined + store_id = "store_id_fooo" self.datastore.persist_event.return_value = defer.succeed(store_id) self.datastore.get_room.return_value = defer.succeed(1) # Not None. @@ -163,8 +174,12 @@ class RoomMemberHandlerTestCase(unittest.TestCase): # Actual invocation yield self.room_member_handler.change_membership(event) - self.state_handler.handle_new_event.assert_called_once_with(event) - self.federation.handle_new_event.assert_called_once_with(event) + self.state_handler.handle_new_event.assert_called_once_with( + event, self.snapshot + ) + self.federation.handle_new_event.assert_called_once_with( + event, self.snapshot + ) self.assertEquals( set(["red", "green"]), diff --git a/tests/utils.py b/tests/utils.py index c68b17f7b..f10bb8960 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -127,6 +127,13 @@ class MemoryDataStore(object): self.current_state = {} self.events = [] + Snapshot = namedtuple("Snapshot", "room_id user_id membership_state") + + def snapshot_room(self, room_id, user_id, state_type=None, state_key=None): + return self.Snapshot( + room_id, user_id, self.get_room_member(user_id, room_id) + ) + def register(self, user_id, token, password_hash): if user_id in self.tokens_to_users.values(): raise StoreError(400, "User in use.")