mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-19 06:42:01 +01:00
docstrings and unittests for storage.state (#3958)
I spent ages trying to figure out how I was going mad...
This commit is contained in:
parent
2c695fd1aa
commit
ae6ad4cf41
3 changed files with 62 additions and 8 deletions
1
changelog.d/3958.misc
Normal file
1
changelog.d/3958.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Fix docstrings and add tests for state store methods
|
|
@ -255,7 +255,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_groups_ids(self, room_id, event_ids):
|
||||
def get_state_groups_ids(self, _room_id, event_ids):
|
||||
"""Get the event IDs of all the state for the state groups for the given events
|
||||
|
||||
Args:
|
||||
_room_id (str): id of the room for these events
|
||||
event_ids (iterable[str]): ids of the events
|
||||
|
||||
Returns:
|
||||
Deferred[dict[int, dict[tuple[str, str], str]]]:
|
||||
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||
"""
|
||||
if not event_ids:
|
||||
defer.returnValue({})
|
||||
|
||||
|
@ -270,7 +280,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_ids_for_group(self, state_group):
|
||||
"""Get the state IDs for the given state group
|
||||
"""Get the event IDs of all the state in the given state group
|
||||
|
||||
Args:
|
||||
state_group (int)
|
||||
|
@ -286,7 +296,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
def get_state_groups(self, room_id, event_ids):
|
||||
""" Get the state groups for the given list of event_ids
|
||||
|
||||
The return value is a dict mapping group names to lists of events.
|
||||
Returns:
|
||||
Deferred[dict[int, list[EventBase]]]:
|
||||
dict of state_group_id -> list of state events.
|
||||
"""
|
||||
if not event_ids:
|
||||
defer.returnValue({})
|
||||
|
@ -324,7 +336,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
member events (if True), or to exclude member events (if False)
|
||||
|
||||
Returns:
|
||||
dictionary state_group -> (dict of (type, state_key) -> event id)
|
||||
Returns:
|
||||
Deferred[dict[int, dict[tuple[str, str], str]]]:
|
||||
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||
"""
|
||||
results = {}
|
||||
|
||||
|
@ -732,8 +746,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
If None, `types` filtering is applied to all events.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[int, dict[(type, state_key), EventBase]]]
|
||||
a dictionary mapping from state group to state dictionary.
|
||||
Deferred[dict[int, dict[tuple[str, str], str]]]:
|
||||
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||
"""
|
||||
if types is not None:
|
||||
non_member_types = [t for t in types if t[0] != EventTypes.Member]
|
||||
|
@ -788,8 +802,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
If None, `types` filtering is applied to all events.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[int, dict[(type, state_key), EventBase]]]
|
||||
a dictionary mapping from state group to state dictionary.
|
||||
Deferred[dict[int, dict[tuple[str, str], str]]]:
|
||||
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||
"""
|
||||
if types:
|
||||
types = frozenset(types)
|
||||
|
|
|
@ -74,6 +74,45 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
|||
self.assertEqual(s1[t].event_id, s2[t].event_id)
|
||||
self.assertEqual(len(s1), len(s2))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_state_groups_ids(self):
|
||||
e1 = yield self.inject_state_event(
|
||||
self.room, self.u_alice, EventTypes.Create, '', {}
|
||||
)
|
||||
e2 = yield self.inject_state_event(
|
||||
self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
|
||||
)
|
||||
|
||||
state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id])
|
||||
self.assertEqual(len(state_group_map), 1)
|
||||
state_map = list(state_group_map.values())[0]
|
||||
self.assertDictEqual(
|
||||
state_map,
|
||||
{
|
||||
(EventTypes.Create, ''): e1.event_id,
|
||||
(EventTypes.Name, ''): e2.event_id,
|
||||
},
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_state_groups(self):
|
||||
e1 = yield self.inject_state_event(
|
||||
self.room, self.u_alice, EventTypes.Create, '', {}
|
||||
)
|
||||
e2 = yield self.inject_state_event(
|
||||
self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
|
||||
)
|
||||
|
||||
state_group_map = yield self.store.get_state_groups(
|
||||
self.room, [e2.event_id])
|
||||
self.assertEqual(len(state_group_map), 1)
|
||||
state_list = list(state_group_map.values())[0]
|
||||
|
||||
self.assertEqual(
|
||||
{ev.event_id for ev in state_list},
|
||||
{e1.event_id, e2.event_id},
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_state_for_event(self):
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue