mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 07:43:48 +01:00
Fix tests
This commit is contained in:
parent
50943ab942
commit
3f11953fcb
1 changed files with 46 additions and 5 deletions
|
@ -67,6 +67,8 @@ class StateGroupStore(object):
|
||||||
self._event_to_state_group = {}
|
self._event_to_state_group = {}
|
||||||
self._group_to_state = {}
|
self._group_to_state = {}
|
||||||
|
|
||||||
|
self._event_id_to_event = {}
|
||||||
|
|
||||||
self._next_group = 1
|
self._next_group = 1
|
||||||
|
|
||||||
def get_state_groups_ids(self, room_id, event_ids):
|
def get_state_groups_ids(self, room_id, event_ids):
|
||||||
|
@ -96,6 +98,16 @@ class StateGroupStore(object):
|
||||||
|
|
||||||
self._event_to_state_group[event.event_id] = state_group
|
self._event_to_state_group[event.event_id] = state_group
|
||||||
|
|
||||||
|
def get_events(self, event_ids, **kwargs):
|
||||||
|
return {
|
||||||
|
e_id: self._event_id_to_event[e_id] for e_id in event_ids
|
||||||
|
if e_id in self._event_id_to_event
|
||||||
|
}
|
||||||
|
|
||||||
|
def register_events(self, events):
|
||||||
|
for e in events:
|
||||||
|
self._event_id_to_event[e.event_id] = e
|
||||||
|
|
||||||
|
|
||||||
class DictObj(dict):
|
class DictObj(dict):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
|
@ -138,6 +150,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
spec_set=[
|
spec_set=[
|
||||||
"get_state_groups_ids",
|
"get_state_groups_ids",
|
||||||
"add_event_hashes",
|
"add_event_hashes",
|
||||||
|
"get_events",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
hs = Mock(spec_set=[
|
hs = Mock(spec_set=[
|
||||||
|
@ -240,6 +253,8 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
store = StateGroupStore()
|
store = StateGroupStore()
|
||||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||||
|
self.store.get_events = store.get_events
|
||||||
|
store.register_events(graph.walk())
|
||||||
|
|
||||||
context_store = {}
|
context_store = {}
|
||||||
|
|
||||||
|
@ -250,7 +265,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"START", "A", "C"},
|
{"START", "A", "C"},
|
||||||
{e.event_id for e in context_store["D"].current_state.values()}
|
{e_id for e_id in context_store["D"].current_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -304,6 +319,8 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
store = StateGroupStore()
|
store = StateGroupStore()
|
||||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||||
|
self.store.get_events = store.get_events
|
||||||
|
store.register_events(graph.walk())
|
||||||
|
|
||||||
context_store = {}
|
context_store = {}
|
||||||
|
|
||||||
|
@ -314,7 +331,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"START", "A", "B", "C"},
|
{"START", "A", "B", "C"},
|
||||||
{e.event_id for e in context_store["E"].current_state.values()}
|
{e for e in context_store["E"].current_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -385,6 +402,8 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
store = StateGroupStore()
|
store = StateGroupStore()
|
||||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||||
|
self.store.get_events = store.get_events
|
||||||
|
store.register_events(graph.walk())
|
||||||
|
|
||||||
context_store = {}
|
context_store = {}
|
||||||
|
|
||||||
|
@ -395,7 +414,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"A1", "A2", "A3", "A5", "B"},
|
{"A1", "A2", "A3", "A5", "B"},
|
||||||
{e.event_id for e in context_store["D"].current_state.values()}
|
{e for e in context_store["D"].current_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_depths(self, nodes, edges):
|
def _add_depths(self, nodes, edges):
|
||||||
|
@ -522,6 +541,11 @@ class StateTestCase(unittest.TestCase):
|
||||||
create_event(type="test4", state_key=""),
|
create_event(type="test4", state_key=""),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
store = StateGroupStore()
|
||||||
|
store.register_events(old_state_1)
|
||||||
|
store.register_events(old_state_2)
|
||||||
|
self.store.get_events = store.get_events
|
||||||
|
|
||||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||||
|
|
||||||
self.assertEqual(len(context.current_state_ids), 6)
|
self.assertEqual(len(context.current_state_ids), 6)
|
||||||
|
@ -550,6 +574,11 @@ class StateTestCase(unittest.TestCase):
|
||||||
create_event(type="test4", state_key=""),
|
create_event(type="test4", state_key=""),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
store = StateGroupStore()
|
||||||
|
store.register_events(old_state_1)
|
||||||
|
store.register_events(old_state_2)
|
||||||
|
self.store.get_events = store.get_events
|
||||||
|
|
||||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||||
|
|
||||||
self.assertEqual(len(context.current_state_ids), 6)
|
self.assertEqual(len(context.current_state_ids), 6)
|
||||||
|
@ -585,9 +614,16 @@ class StateTestCase(unittest.TestCase):
|
||||||
create_event(type="test1", state_key="1", depth=2),
|
create_event(type="test1", state_key="1", depth=2),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
store = StateGroupStore()
|
||||||
|
store.register_events(old_state_1)
|
||||||
|
store.register_events(old_state_2)
|
||||||
|
self.store.get_events = store.get_events
|
||||||
|
|
||||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||||
|
|
||||||
self.assertEqual(old_state_2[2].event.id, context.current_state_ids[("test1", "1")])
|
self.assertEqual(
|
||||||
|
old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
|
||||||
|
)
|
||||||
|
|
||||||
# Reverse the depth to make sure we are actually using the depths
|
# Reverse the depth to make sure we are actually using the depths
|
||||||
# during state resolution.
|
# during state resolution.
|
||||||
|
@ -604,9 +640,14 @@ class StateTestCase(unittest.TestCase):
|
||||||
create_event(type="test1", state_key="1", depth=1),
|
create_event(type="test1", state_key="1", depth=1),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
store.register_events(old_state_1)
|
||||||
|
store.register_events(old_state_2)
|
||||||
|
|
||||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||||
|
|
||||||
self.assertEqual(old_state_1[2].event_id, context.current_state_ids[("test1", "1")])
|
self.assertEqual(
|
||||||
|
old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
|
||||||
|
)
|
||||||
|
|
||||||
def _get_context(self, event, old_state_1, old_state_2):
|
def _get_context(self, event, old_state_1, old_state_2):
|
||||||
group_name_1 = "group_name_1"
|
group_name_1 = "group_name_1"
|
||||||
|
|
Loading…
Reference in a new issue