0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-15 22:42:23 +01:00

Fix bug in state handling where we incorrectly identified a missing pdu. Update tests to catch this case.

This commit is contained in:
Erik Johnston 2014-09-08 19:50:46 +01:00
parent de727f854a
commit 83ce57302d
3 changed files with 271 additions and 71 deletions

View file

@ -134,7 +134,9 @@ class StateHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _handle_new_state(self, new_pdu): def _handle_new_state(self, new_pdu):
tree = yield self.store.get_unresolved_state_tree(new_pdu) tree, missing_branch = yield self.store.get_unresolved_state_tree(
new_pdu
)
new_branch, current_branch = tree new_branch, current_branch = tree
logger.debug( logger.debug(
@ -142,6 +144,28 @@ class StateHandler(object):
new_branch, current_branch new_branch, current_branch
) )
if missing_branch is not None:
# We're missing some PDUs. Fetch them.
# TODO (erikj): Limit this.
missing_prev = tree[missing_branch][-1]
pdu_id = missing_prev.prev_state_id
origin = missing_prev.prev_state_origin
is_missing = yield self.store.get_pdu(pdu_id, origin) is None
if not is_missing:
raise Exception("Conflict resolution failed")
yield self._replication.get_pdu(
destination=missing_prev.origin,
pdu_origin=origin,
pdu_id=pdu_id,
outlier=True
)
updated_current = yield self._handle_new_state(new_pdu)
defer.returnValue(updated_current)
if not current_branch: if not current_branch:
# There is no current state # There is no current state
defer.returnValue(True) defer.returnValue(True)
@ -151,13 +175,20 @@ class StateHandler(object):
c = current_branch[-1] c = current_branch[-1]
if n.pdu_id == c.pdu_id and n.origin == c.origin: if n.pdu_id == c.pdu_id and n.origin == c.origin:
# We have all the PDUs we need, so we can just do the conflict # We found a common ancestor!
# resolution.
if len(current_branch) == 1: if len(current_branch) == 1:
# This is a direct clobber so we can just... # This is a direct clobber so we can just...
defer.returnValue(True) defer.returnValue(True)
else:
# We didn't find a common ancestor. This is probably fine.
pass
result = self._do_conflict_res(new_branch, current_branch)
defer.returnValue(result)
def _do_conflict_res(self, new_branch, current_branch):
conflict_res = [ conflict_res = [
self._do_power_level_conflict_res, self._do_power_level_conflict_res,
self._do_chain_length_conflict_res, self._do_chain_length_conflict_res,
@ -174,43 +205,6 @@ class StateHandler(object):
raise Exception("Conflict resolution failed.") raise Exception("Conflict resolution failed.")
else:
# We need to ask for PDUs.
missing_prev = max(
new_branch[-1], current_branch[-1],
key=lambda x: x.depth
)
if not hasattr(missing_prev, "prev_state_id"):
# FIXME Hmm
# temporary fallback
for algo in conflict_res:
new_res, curr_res = algo(new_branch, current_branch)
if new_res < curr_res:
defer.returnValue(False)
elif new_res > curr_res:
defer.returnValue(True)
return
pdu_id = missing_prev.prev_state_id
origin = missing_prev.prev_state_origin
is_missing = yield self.store.get_pdu(pdu_id, origin) is None
if not is_missing:
raise Exception("Conflict resolution failed.")
yield self._replication.get_pdu(
destination=missing_prev.origin,
pdu_origin=origin,
pdu_id=pdu_id,
outlier=True
)
updated_current = yield self._handle_new_state(new_pdu)
defer.returnValue(updated_current)
def _do_power_level_conflict_res(self, new_branch, current_branch): def _do_power_level_conflict_res(self, new_branch, current_branch):
max_power_new = max( max_power_new = max(
new_branch[:-1], new_branch[:-1],

View file

@ -308,7 +308,7 @@ class PduStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_oldest_pdus_in_context(self, context): def get_oldest_pdus_in_context(self, context):
"""Get a list of Pdus that we haven't backfilled beyond yet (and haven't """Get a list of Pdus that we haven't backfilled beyond yet (and havent
seen). This list is used when we want to backfill backwards and is the seen). This list is used when we want to backfill backwards and is the
list we send to the remote server. list we send to the remote server.
@ -524,13 +524,16 @@ class StatePduStore(SQLBaseStore):
txn, new_pdu, current txn, new_pdu, current
) )
missing_branch = None
for branch, prev_state, state in enum_branches: for branch, prev_state, state in enum_branches:
if state: if state:
return_value[branch].append(state) return_value[branch].append(state)
else: else:
# We don't have prev_state :(
missing_branch = branch
break break
return return_value return (return_value, missing_branch)
def update_current_state(self, pdu_id, origin, context, pdu_type, def update_current_state(self, pdu_id, origin, context, pdu_type,
state_key): state_key):

View file

@ -24,6 +24,8 @@ from collections import namedtuple
from mock import Mock from mock import Mock
import mock
ReturnType = namedtuple( ReturnType = namedtuple(
"StateReturnType", ["new_branch", "current_branch"] "StateReturnType", ["new_branch", "current_branch"]
@ -54,7 +56,7 @@ class StateTestCase(unittest.TestCase):
new_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) new_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
self.persistence.get_unresolved_state_tree.return_value = ( self.persistence.get_unresolved_state_tree.return_value = (
ReturnType([new_pdu], []) (ReturnType([new_pdu], []), None)
) )
is_new = yield self.state.handle_new_state(new_pdu) is_new = yield self.state.handle_new_state(new_pdu)
@ -78,7 +80,7 @@ class StateTestCase(unittest.TestCase):
new_pdu = new_fake_pdu_entry("B", "test", "mem", "x", "A", 5) new_pdu = new_fake_pdu_entry("B", "test", "mem", "x", "A", 5)
self.persistence.get_unresolved_state_tree.return_value = ( self.persistence.get_unresolved_state_tree.return_value = (
ReturnType([new_pdu, old_pdu], [old_pdu]) (ReturnType([new_pdu, old_pdu], [old_pdu]), None)
) )
is_new = yield self.state.handle_new_state(new_pdu) is_new = yield self.state.handle_new_state(new_pdu)
@ -103,7 +105,7 @@ class StateTestCase(unittest.TestCase):
new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 5) new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 5)
self.persistence.get_unresolved_state_tree.return_value = ( self.persistence.get_unresolved_state_tree.return_value = (
ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]) (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None)
) )
is_new = yield self.state.handle_new_state(new_pdu) is_new = yield self.state.handle_new_state(new_pdu)
@ -128,7 +130,7 @@ class StateTestCase(unittest.TestCase):
new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 15) new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 15)
self.persistence.get_unresolved_state_tree.return_value = ( self.persistence.get_unresolved_state_tree.return_value = (
ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]) (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None)
) )
is_new = yield self.state.handle_new_state(new_pdu) is_new = yield self.state.handle_new_state(new_pdu)
@ -153,7 +155,7 @@ class StateTestCase(unittest.TestCase):
new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 10) new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 10)
self.persistence.get_unresolved_state_tree.return_value = ( self.persistence.get_unresolved_state_tree.return_value = (
ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]) (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None)
) )
is_new = yield self.state.handle_new_state(new_pdu) is_new = yield self.state.handle_new_state(new_pdu)
@ -179,7 +181,13 @@ class StateTestCase(unittest.TestCase):
new_pdu = new_fake_pdu_entry("D", "test", "mem", "x", "C", 10) new_pdu = new_fake_pdu_entry("D", "test", "mem", "x", "C", 10)
self.persistence.get_unresolved_state_tree.return_value = ( self.persistence.get_unresolved_state_tree.return_value = (
ReturnType([new_pdu, old_pdu_3, old_pdu_1], [old_pdu_2, old_pdu_1]) (
ReturnType(
[new_pdu, old_pdu_3, old_pdu_1],
[old_pdu_2, old_pdu_1]
),
None
)
) )
is_new = yield self.state.handle_new_state(new_pdu) is_new = yield self.state.handle_new_state(new_pdu)
@ -200,22 +208,32 @@ class StateTestCase(unittest.TestCase):
# triggering a get_pdu request # triggering a get_pdu request
# The pdu we haven't seen # The pdu we haven't seen
old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) old_pdu_1 = new_fake_pdu_entry(
"A", "test", "mem", "x", None, 10, depth=0
)
old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10) old_pdu_2 = new_fake_pdu_entry(
new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20) "B", "test", "mem", "x", "A", 10, depth=1
)
new_pdu = new_fake_pdu_entry(
"C", "test", "mem", "x", "A", 20, depth=2
)
# The return_value of `get_unresolved_state_tree`, which changes after # The return_value of `get_unresolved_state_tree`, which changes after
# the call to get_pdu # the call to get_pdu
tree_to_return = [ReturnType([new_pdu], [old_pdu_2])] tree_to_return = [(ReturnType([new_pdu], [old_pdu_2]), 0)]
def return_tree(p): def return_tree(p):
return tree_to_return[0] return tree_to_return[0]
def set_return_tree(*args, **kwargs): def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
tree_to_return[0] = ReturnType( tree_to_return[0] = (
ReturnType(
[new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1] [new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]
),
None
) )
return defer.succeed(None)
self.persistence.get_unresolved_state_tree.side_effect = return_tree self.persistence.get_unresolved_state_tree.side_effect = return_tree
@ -227,6 +245,13 @@ class StateTestCase(unittest.TestCase):
self.assertTrue(is_new) self.assertTrue(is_new)
self.replication.get_pdu.assert_called_with(
destination=new_pdu.origin,
pdu_origin=old_pdu_1.origin,
pdu_id=old_pdu_1.pdu_id,
outlier=True
)
self.persistence.get_unresolved_state_tree.assert_called_with( self.persistence.get_unresolved_state_tree.assert_called_with(
new_pdu new_pdu
) )
@ -237,6 +262,184 @@ class StateTestCase(unittest.TestCase):
self.assertEqual(1, self.persistence.update_current_state.call_count) self.assertEqual(1, self.persistence.update_current_state.call_count)
@defer.inlineCallbacks
def test_missing_pdu_depth_1(self):
# We try to update state against a PDU we haven't yet seen,
# triggering a get_pdu request
# The pdu we haven't seen
old_pdu_1 = new_fake_pdu_entry(
"A", "test", "mem", "x", None, 10, depth=0
)
old_pdu_2 = new_fake_pdu_entry(
"B", "test", "mem", "x", "A", 10, depth=2
)
old_pdu_3 = new_fake_pdu_entry(
"C", "test", "mem", "x", "B", 10, depth=3
)
new_pdu = new_fake_pdu_entry(
"D", "test", "mem", "x", "A", 20, depth=4
)
# The return_value of `get_unresolved_state_tree`, which changes after
# the call to get_pdu
tree_to_return = [
(
ReturnType([new_pdu], [old_pdu_3]),
0
),
(
ReturnType(
[new_pdu, old_pdu_1], [old_pdu_3]
),
1
),
(
ReturnType(
[new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1]
),
None
),
]
to_return = [0]
def return_tree(p):
return tree_to_return[to_return[0]]
def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
to_return[0] += 1
return defer.succeed(None)
self.persistence.get_unresolved_state_tree.side_effect = return_tree
self.replication.get_pdu.side_effect = set_return_tree
self.persistence.get_pdu.return_value = None
is_new = yield self.state.handle_new_state(new_pdu)
self.assertTrue(is_new)
self.assertEqual(2, self.replication.get_pdu.call_count)
self.replication.get_pdu.assert_has_calls(
[
mock.call(
destination=new_pdu.origin,
pdu_origin=old_pdu_1.origin,
pdu_id=old_pdu_1.pdu_id,
outlier=True
),
mock.call(
destination=old_pdu_3.origin,
pdu_origin=old_pdu_2.origin,
pdu_id=old_pdu_2.pdu_id,
outlier=True
),
]
)
self.persistence.get_unresolved_state_tree.assert_called_with(
new_pdu
)
self.assertEquals(
3, self.persistence.get_unresolved_state_tree.call_count
)
self.assertEqual(1, self.persistence.update_current_state.call_count)
@defer.inlineCallbacks
def test_missing_pdu_depth_2(self):
# We try to update state against a PDU we haven't yet seen,
# triggering a get_pdu request
# The pdu we haven't seen
old_pdu_1 = new_fake_pdu_entry(
"A", "test", "mem", "x", None, 10, depth=0
)
old_pdu_2 = new_fake_pdu_entry(
"B", "test", "mem", "x", "A", 10, depth=2
)
old_pdu_3 = new_fake_pdu_entry(
"C", "test", "mem", "x", "B", 10, depth=3
)
new_pdu = new_fake_pdu_entry(
"D", "test", "mem", "x", "A", 20, depth=1
)
# The return_value of `get_unresolved_state_tree`, which changes after
# the call to get_pdu
tree_to_return = [
(
ReturnType([new_pdu], [old_pdu_3]),
1,
),
(
ReturnType(
[new_pdu], [old_pdu_3, old_pdu_2]
),
0,
),
(
ReturnType(
[new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1]
),
None
),
]
to_return = [0]
def return_tree(p):
return tree_to_return[to_return[0]]
def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
to_return[0] += 1
return defer.succeed(None)
self.persistence.get_unresolved_state_tree.side_effect = return_tree
self.replication.get_pdu.side_effect = set_return_tree
self.persistence.get_pdu.return_value = None
is_new = yield self.state.handle_new_state(new_pdu)
self.assertTrue(is_new)
self.assertEqual(2, self.replication.get_pdu.call_count)
self.replication.get_pdu.assert_has_calls(
[
mock.call(
destination=old_pdu_3.origin,
pdu_origin=old_pdu_2.origin,
pdu_id=old_pdu_2.pdu_id,
outlier=True
),
mock.call(
destination=new_pdu.origin,
pdu_origin=old_pdu_1.origin,
pdu_id=old_pdu_1.pdu_id,
outlier=True
),
]
)
self.persistence.get_unresolved_state_tree.assert_called_with(
new_pdu
)
self.assertEquals(
3, self.persistence.get_unresolved_state_tree.call_count
)
self.assertEqual(1, self.persistence.update_current_state.call_count)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_new_event(self): def test_new_event(self):
event = Mock() event = Mock()
@ -270,7 +473,7 @@ class StateTestCase(unittest.TestCase):
def new_fake_pdu_entry(pdu_id, context, pdu_type, state_key, prev_state_id, def new_fake_pdu_entry(pdu_id, context, pdu_type, state_key, prev_state_id,
power_level): power_level, depth=0):
new_pdu = PduEntry( new_pdu = PduEntry(
pdu_id=pdu_id, pdu_id=pdu_id,
pdu_type=pdu_type, pdu_type=pdu_type,
@ -280,7 +483,7 @@ def new_fake_pdu_entry(pdu_id, context, pdu_type, state_key, prev_state_id,
origin="example.com", origin="example.com",
context="context", context="context",
ts=1405353060021, ts=1405353060021,
depth=0, depth=depth,
content_json="{}", content_json="{}",
unrecognized_keys="{}", unrecognized_keys="{}",
outlier=True, outlier=True,