Take a snapshot of the state of the room before performing updates

This commit is contained in:
Mark Haines 2014-08-22 17:00:10 +01:00
parent d12a7c3939
commit 1379dcae6f
6 changed files with 162 additions and 58 deletions

View file

@ -34,7 +34,7 @@ class Auth(object):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def check(self, event, raises=False):
def check(self, event, snapshot, raises=False):
""" Checks if this event is correctly authed.
Returns:
@ -46,7 +46,11 @@ class Auth(object):
try:
if event.type in [RoomTopicEvent.TYPE, MessageEvent.TYPE,
FeedbackEvent.TYPE]:
yield self.check_joined_room(event.room_id, event.user_id)
self._check_joined_room(
member=snapshot.membership_state,
user_id=snapshot.user_id,
room_id=snapshot.room_id,
)
defer.returnValue(True)
elif event.type == RoomMemberEvent.TYPE:
allowed = yield self.is_membership_change_allowed(event)
@ -67,14 +71,16 @@ class Auth(object):
room_id=room_id,
user_id=user_id
)
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s" %
(user_id, room_id))
self._check_joined_room(member, user_id, room_id)
defer.returnValue(member)
except AttributeError:
pass
defer.returnValue(None)
def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
@defer.inlineCallbacks
def is_membership_change_allowed(self, event):
# does this room even exist

View file

@ -51,19 +51,20 @@ class FederationEventHandler(object):
@log_function
@defer.inlineCallbacks
def handle_new_event(self, event):
def handle_new_event(self, event, snapshot):
""" Takes in an event from the client to server side, that has already
been authed and handled by the state module, and sends it to any
remote home servers that may be interested.
Args:
event
snapshot (.storage.Snapshot): THe snapshot the event happened after
Returns:
Deferred: Resolved when it has successfully been queued for
processing.
"""
yield self.fill_out_prev_events(event)
yield self.fill_out_prev_events(event, snapshot)
pdu = self.pdu_codec.pdu_from_event(event)
@ -137,13 +138,11 @@ class FederationEventHandler(object):
yield self.event_handler.on_receive(new_state_event)
@defer.inlineCallbacks
def fill_out_prev_events(self, event):
def fill_out_prev_events(self, event, snapshot):
if hasattr(event, "prev_events"):
return
results = yield self.store.get_latest_pdus_in_context(
event.room_id
)
results = snapshot.prev_pdus
es = [
"%s@%s" % (p_id, origin) for p_id, origin, _ in results

View file

@ -85,9 +85,10 @@ class MessageHandler(BaseHandler):
if stamp_event:
event.content["hsob_ts"] = int(self.clock.time_msec())
with (yield self.room_lock.lock(event.room_id)):
snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
if not suppress_auth:
yield self.auth.check(event, raises=True)
yield self.auth.check(event, snapshot, raises=True)
# store message in db
store_id = yield self.store.persist_event(event)
@ -98,7 +99,7 @@ class MessageHandler(BaseHandler):
self.notifier.on_new_room_event(event, store_id)
yield self.hs.get_federation().handle_new_event(event)
yield self.hs.get_federation().handle_new_event(event, snapshot)
@defer.inlineCallbacks
def get_messages(self, user_id=None, room_id=None, pagin_config=None,
@ -135,8 +136,9 @@ class MessageHandler(BaseHandler):
SynapseError if something went wrong.
"""
with (yield self.room_lock.lock(event.room_id)):
yield self.auth.check(event, raises=True)
snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
yield self.auth.check(event, snapshot, raises=True)
if stamp_event:
event.content["hsob_ts"] = int(self.clock.time_msec())
@ -151,7 +153,7 @@ class MessageHandler(BaseHandler):
)
self.notifier.on_new_room_event(event, store_id)
yield self.hs.get_federation().handle_new_event(event)
yield self.hs.get_federation().handle_new_event(event, snapshot)
@defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None,
@ -220,8 +222,10 @@ class MessageHandler(BaseHandler):
if stamp_event:
event.content["hsob_ts"] = int(self.clock.time_msec())
with (yield self.room_lock.lock(event.room_id)):
yield self.auth.check(event, raises=True)
snapshot = yield self.store.snapshot_room(event.room_id, user_id)
yield self.auth.check(event, snapshot, raises=True)
# store message in db
store_id = yield self.store.persist_event(event)
@ -229,7 +233,7 @@ class MessageHandler(BaseHandler):
event.destinations = yield self.store.get_joined_hosts_for_room(
event.room_id
)
yield self.hs.get_federation().handle_new_event(event)
yield self.hs.get_federation().handle_new_event(event, snapshot)
self.notifier.on_new_room_event(event, store_id)
@ -503,6 +507,11 @@ class RoomMemberHandler(BaseHandler):
SynapseError if there was a problem changing the membership.
"""
snapshot = yield self.store.snapshot_room(
event.room_id, event.user_id,
RoomMemberEvent.TYPE, event.target_user_id
)
## TODO(markjh): get prev state from snapshot.
prev_state = yield self.store.get_room_member(
event.target_user_id, event.room_id
)
@ -523,24 +532,22 @@ class RoomMemberHandler(BaseHandler):
# if this HS is not currently in the room, i.e. we have to do the
# invite/join dance.
if event.membership == Membership.JOIN:
yield self._do_join(event, do_auth=do_auth)
yield self._do_join(event, snapshot, do_auth=do_auth)
else:
# This is not a JOIN, so we can handle it normally.
if do_auth:
yield self.auth.check(event, raises=True)
yield self.auth.check(event, snapshot, raises=True)
prev_state = yield self.store.get_room_member(
event.target_user_id, event.room_id
)
if prev_state and prev_state.membership == event.membership:
# double same action, treat this event as a NOOP.
defer.returnValue({})
return
yield self.state_handler.handle_new_event(event)
yield self.state_handler.handle_new_event(event, snapshot)
yield self._do_local_membership_update(
event,
membership=event.content["membership"],
snapshot=snapshot,
)
defer.returnValue({"room_id": room_id})
@ -570,12 +577,16 @@ class RoomMemberHandler(BaseHandler):
content=content,
)
yield self._do_join(new_event, room_host=host, do_auth=True)
snapshot = yield store.snapshot_room(
room_id, joinee, RoomMemberEvent.TYPE, event.target_user_id
)
yield self._do_join(new_event, snapshot, room_host=host, do_auth=True)
defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks
def _do_join(self, event, room_host=None, do_auth=True):
def _do_join(self, event, snapshot, room_host=None, do_auth=True):
joinee = self.hs.parse_userid(event.target_user_id)
# room_id = RoomID.from_string(event.room_id, self.hs)
room_id = event.room_id
@ -597,6 +608,7 @@ class RoomMemberHandler(BaseHandler):
elif room_host:
should_do_dance = True
else:
# TODO(markjh): get prev_state from snapshot
prev_state = yield self.store.get_room_member(
joinee.to_string(), room_id
)
@ -624,12 +636,13 @@ class RoomMemberHandler(BaseHandler):
logger.debug("Doing normal join")
if do_auth:
yield self.auth.check(event, raises=True)
yield self.auth.check(event, snapshot, raises=True)
yield self.state_handler.handle_new_event(event)
yield self.state_handler.handle_new_event(event, snapshot)
yield self._do_local_membership_update(
event,
membership=event.content["membership"],
snapshot=snapshot,
)
user = self.hs.parse_userid(event.user_id)
@ -674,7 +687,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue([r.room_id for r in rooms])
@defer.inlineCallbacks
def _do_local_membership_update(self, event, membership):
def _do_local_membership_update(self, event, membership, snapshot):
# store membership
store_id = yield self.store.persist_event(event)
@ -700,7 +713,7 @@ class RoomMemberHandler(BaseHandler):
event.destinations = list(set(destinations))
yield self.hs.get_federation().handle_new_event(event)
yield self.hs.get_federation().handle_new_event(event, snapshot)
self.notifier.on_new_room_event(event, store_id)

View file

@ -187,6 +187,70 @@ class DataStore(RoomMemberStore, RoomStore,
defer.returnValue(self.min_token)
def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
"""Snapshot the room for an update by a user
Args:
room_id (synapse.types.RoomId): The room to snapshot.
user_id (synapse.types.UserId): The user to snapshot the room for.
state_type (str): Optional state type to snapshot.
state_key (str): Optional state key to snapshot.
Returns:
synapse.storage.Snapshot: A snapshot of the state of the room.
"""
def _snapshot(txn):
membership_state = self._get_room_member(txn, user_id)
prev_pdus = self._get_latest_pdus_in_context(
txn, room_id
)
if state_type is not None and state_key is not None:
prev_state_pdu = self._get_current_state_pdu(
txn, room_id, state_type, state_key
)
else:
prev_state_pdu = None
return Snapshot(
store=self,
room_id=room_id,
user_id=user_id,
prev_pdus=prev_pdus,
membership_state=membership_state,
state_type=state_type,
state_key=state_key,
prev_state_pdu=prev_state_pdu,
)
return self._db_pool.runInteraction(_snapshot)
class Snapshot(object):
"""Snapshot of the state of a room
Args:
store (DataStore): The datastore.
room_id (RoomId): The room of the snapshot.
user_id (UserId): The user this snapshot is for.
prev_pdus (list): The list of PDU ids this snapshot is after.
membership_state (RoomMemberEvent): The current state of the user in
the room.
state_type (str, optional): State type captured by the snapshot
state_key (str, optional): State key captured by the snapshot
prev_state_pdu (PduEntry, optional): pdu id of
the previous value of the state type and key in the room.
"""
def __init__(self, store, room_id, user_id, prev_pdus,
membership_state, state_type=None, state_key=None,
prev_state_pdu=None):
self.store = store
self.room_id = room_id
self.user_id = user_id
self.prev_pdus = prev_pdus
self.membership_state
self.state_type = state_type
self.state_key = state_key
self.prev_state_pdu = prev_state_pdu
def schema_path(schema):
""" Get a filesystem path for the named database schema

View file

@ -45,6 +45,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
"get_room_member",
"get_room",
"store_room",
"snapshot_room",
]),
resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]),
@ -75,6 +76,10 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.handlers.profile_handler = ProfileHandler(self.hs)
self.room_member_handler = self.handlers.room_member_handler
self.snapshot = Mock()
self.datastore.snapshot_room.return_value = self.snapshot
@defer.inlineCallbacks
def test_invite(self):
room_id = "!foo:red"
@ -104,8 +109,12 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
# Actual invocation
yield self.room_member_handler.change_membership(event)
self.state_handler.handle_new_event.assert_called_once_with(event)
self.federation.handle_new_event.assert_called_once_with(event)
self.state_handler.handle_new_event.assert_called_once_with(
event, self.snapshot,
)
self.federation.handle_new_event.assert_called_once_with(
event, self.snapshot,
)
self.assertEquals(
set(["blue", "red", "green"]),
@ -116,7 +125,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
event
)
self.notifier.on_new_room_event.assert_called_once_with(
event, store_id)
event, store_id
)
self.assertFalse(self.datastore.get_room.called)
self.assertFalse(self.datastore.store_room.called)
@ -148,6 +158,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.datastore.get_joined_hosts_for_room.side_effect = get_joined
store_id = "store_id_fooo"
self.datastore.persist_event.return_value = defer.succeed(store_id)
self.datastore.get_room.return_value = defer.succeed(1) # Not None.
@ -163,8 +174,12 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
# Actual invocation
yield self.room_member_handler.change_membership(event)
self.state_handler.handle_new_event.assert_called_once_with(event)
self.federation.handle_new_event.assert_called_once_with(event)
self.state_handler.handle_new_event.assert_called_once_with(
event, self.snapshot
)
self.federation.handle_new_event.assert_called_once_with(
event, self.snapshot
)
self.assertEquals(
set(["red", "green"]),

View file

@ -127,6 +127,13 @@ class MemoryDataStore(object):
self.current_state = {}
self.events = []
Snapshot = namedtuple("Snapshot", "room_id user_id membership_state")
def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
return self.Snapshot(
room_id, user_id, self.get_room_member(user_id, room_id)
)
def register(self, user_id, token, password_hash):
if user_id in self.tokens_to_users.values():
raise StoreError(400, "User in use.")