Start updating state handling to use snapshots

This commit is contained in:
Mark Haines 2014-08-27 15:11:51 +01:00
parent 46a2f6a816
commit a0d1f5a014
4 changed files with 24 additions and 25 deletions

View file

@ -133,7 +133,7 @@ class MessageHandler(BaseRoomHandler):
if stamp_event:
event.content["hsob_ts"] = int(self.clock.time_msec())
yield self.state_handler.handle_new_event(event)
yield self.state_handler.handle_new_event(event, snapshot)
yield self._on_new_room_event(event, snapshot)
@ -362,6 +362,13 @@ class RoomCreationHandler(BaseRoomHandler):
content=config,
)
snapshot = yield self.store.snapshot_room(
room_id=room_id,
user_id=user_id,
state_type=RoomConfigEvent.TYPE,
state_key="",
)
if room_alias:
yield self.store.create_room_alias_association(
room_id=room_id,
@ -369,11 +376,11 @@ class RoomCreationHandler(BaseRoomHandler):
servers=[self.hs.hostname],
)
yield self.state_handler.handle_new_event(config_event)
yield self.state_handler.handle_new_event(config_event, snapshot)
# store_id = persist...
federation_handler = self.hs.get_handlers().federation_handler
yield federation_handler.handle_new_event(config_event)
yield federation_handler.handle_new_event(config_event, snapshot)
# self.notifier.on_new_room_event(event, store_id)
content = {"membership": Membership.JOIN}

View file

@ -45,7 +45,7 @@ class StateHandler(object):
@defer.inlineCallbacks
@log_function
def handle_new_event(self, event):
def handle_new_event(self, event, snapshot):
""" Given an event this works out if a) we have sufficient power level
to update the state and b) works out what the prev_state should be.
@ -70,25 +70,13 @@ class StateHandler(object):
# Now I need to fill out the prev state and work out if it has auth
# (w.r.t. to power levels)
results = yield self.store.get_latest_pdus_in_context(
event.room_id
)
snapshot.fill_out_prev_events(event)
event.prev_events = [
encode_event_id(p_id, origin) for p_id, origin, _ in results
]
event.prev_events = [
e for e in event.prev_events if e != event.event_id
]
if results:
event.depth = max([int(v) for _, _, v in results]) + 1
else:
event.depth = 0
current_state = yield self.store.get_current_state_pdu(
key.context, key.type, key.state_key
)
current_state = snapshot.prev_state_pdu
if current_state:
event.prev_state = encode_event_id(

View file

@ -330,6 +330,7 @@ class RoomCreationTest(unittest.TestCase):
db_pool=None,
datastore=NonCallableMock(spec_set=[
"store_room",
"snapshot_room",
]),
http_client=NonCallableMock(spec_set=[]),
notifier=NonCallableMock(spec_set=["on_new_room_event"]),

View file

@ -243,21 +243,24 @@ class StateTestCase(unittest.TestCase):
state_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20)
tup = ("pdu_id", "origin.com", 5)
pdus = [tup]
snapshot = Mock()
snapshot.prev_state_pdu = state_pdu
event_id = "pdu_id@origin.com"
self.persistence.get_latest_pdus_in_context.return_value = pdus
self.persistence.get_current_state_pdu.return_value = state_pdu
def fill_out_prev_events(event):
event.prev_events = [event_id]
event.depth = 6
snapshot.fill_out_prev_events = fill_out_prev_events
yield self.state.handle_new_event(event)
yield self.state.handle_new_event(event, snapshot)
self.assertLess(tup[2], event.depth)
self.assertLess(5, event.depth)
self.assertEquals(1, len(event.prev_events))
prev_id = event.prev_events[0]
self.assertEqual(encode_event_id(tup[0], tup[1]), prev_id)
self.assertEqual(event_id, prev_id)
self.assertEqual(
encode_event_id(state_pdu.pdu_id, state_pdu.origin),