mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 12:13:52 +01:00
Handle the case where we don't have a common ancestor
This commit is contained in:
parent
2eaa199e6a
commit
942d8412c4
2 changed files with 42 additions and 9 deletions
|
@ -174,7 +174,9 @@ class StateHandler(object):
|
||||||
n = new_branch[-1]
|
n = new_branch[-1]
|
||||||
c = current_branch[-1]
|
c = current_branch[-1]
|
||||||
|
|
||||||
if n.pdu_id == c.pdu_id and n.origin == c.origin:
|
common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin
|
||||||
|
|
||||||
|
if common_ancestor:
|
||||||
# We found a common ancestor!
|
# We found a common ancestor!
|
||||||
|
|
||||||
if len(current_branch) == 1:
|
if len(current_branch) == 1:
|
||||||
|
@ -185,10 +187,12 @@ class StateHandler(object):
|
||||||
# We didn't find a common ancestor. This is probably fine.
|
# We didn't find a common ancestor. This is probably fine.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
result = self._do_conflict_res(new_branch, current_branch)
|
result = self._do_conflict_res(
|
||||||
|
new_branch, current_branch, common_ancestor
|
||||||
|
)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
def _do_conflict_res(self, new_branch, current_branch):
|
def _do_conflict_res(self, new_branch, current_branch, common_ancestor):
|
||||||
conflict_res = [
|
conflict_res = [
|
||||||
self._do_power_level_conflict_res,
|
self._do_power_level_conflict_res,
|
||||||
self._do_chain_length_conflict_res,
|
self._do_chain_length_conflict_res,
|
||||||
|
@ -196,7 +200,9 @@ class StateHandler(object):
|
||||||
]
|
]
|
||||||
|
|
||||||
for algo in conflict_res:
|
for algo in conflict_res:
|
||||||
new_res, curr_res = algo(new_branch, current_branch)
|
new_res, curr_res = algo(
|
||||||
|
new_branch, current_branch, common_ancestor
|
||||||
|
)
|
||||||
|
|
||||||
if new_res < curr_res:
|
if new_res < curr_res:
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
|
@ -205,23 +211,26 @@ class StateHandler(object):
|
||||||
|
|
||||||
raise Exception("Conflict resolution failed.")
|
raise Exception("Conflict resolution failed.")
|
||||||
|
|
||||||
def _do_power_level_conflict_res(self, new_branch, current_branch):
|
def _do_power_level_conflict_res(self, new_branch, current_branch,
|
||||||
|
common_ancestor):
|
||||||
max_power_new = max(
|
max_power_new = max(
|
||||||
new_branch[:-1],
|
new_branch[:-1] if common_ancestor else new_branch,
|
||||||
key=lambda t: t.power_level
|
key=lambda t: t.power_level
|
||||||
).power_level
|
).power_level
|
||||||
|
|
||||||
max_power_current = max(
|
max_power_current = max(
|
||||||
current_branch[:-1],
|
current_branch[:-1] if common_ancestor else current_branch,
|
||||||
key=lambda t: t.power_level
|
key=lambda t: t.power_level
|
||||||
).power_level
|
).power_level
|
||||||
|
|
||||||
return (max_power_new, max_power_current)
|
return (max_power_new, max_power_current)
|
||||||
|
|
||||||
def _do_chain_length_conflict_res(self, new_branch, current_branch):
|
def _do_chain_length_conflict_res(self, new_branch, current_branch,
|
||||||
|
common_ancestor):
|
||||||
return (len(new_branch), len(current_branch))
|
return (len(new_branch), len(current_branch))
|
||||||
|
|
||||||
def _do_hash_conflict_res(self, new_branch, current_branch):
|
def _do_hash_conflict_res(self, new_branch, current_branch,
|
||||||
|
common_ancestor):
|
||||||
new_str = "".join([p.pdu_id + p.origin for p in new_branch])
|
new_str = "".join([p.pdu_id + p.origin for p in new_branch])
|
||||||
c_str = "".join([p.pdu_id + p.origin for p in current_branch])
|
c_str = "".join([p.pdu_id + p.origin for p in current_branch])
|
||||||
|
|
||||||
|
|
|
@ -440,6 +440,30 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(1, self.persistence.update_current_state.call_count)
|
self.assertEqual(1, self.persistence.update_current_state.call_count)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_no_common_ancestor(self):
|
||||||
|
# We do a direct overwriting of the old state, i.e., the new state
|
||||||
|
# points to the old state.
|
||||||
|
|
||||||
|
old_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 5)
|
||||||
|
new_pdu = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
|
||||||
|
|
||||||
|
self.persistence.get_unresolved_state_tree.return_value = (
|
||||||
|
(ReturnType([new_pdu], [old_pdu]), None)
|
||||||
|
)
|
||||||
|
|
||||||
|
is_new = yield self.state.handle_new_state(new_pdu)
|
||||||
|
|
||||||
|
self.assertTrue(is_new)
|
||||||
|
|
||||||
|
self.persistence.get_unresolved_state_tree.assert_called_once_with(
|
||||||
|
new_pdu
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(1, self.persistence.update_current_state.call_count)
|
||||||
|
|
||||||
|
self.assertFalse(self.replication.get_pdu.called)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_new_event(self):
|
def test_new_event(self):
|
||||||
event = Mock()
|
event = Mock()
|
||||||
|
|
Loading…
Reference in a new issue