diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 2f3a70b4e5..55ea567793 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,43 +14,71 @@ # limitations under the License. from ._base import SQLBaseStore -from twisted.internet import defer class StateStore(SQLBaseStore): + """ Keeps track of the state at a given event. + + This is done by the concept of `state groups`. Every event is a assigned + a state group (identified by an arbitrary string), which references a + collection of state events. The current state of an event is then the + collection of state events referenced by the event's state group. + + Hence, every change in the current state causes a new state group to be + generated. However, if no change happens (e.g., if we get a message event + with only one parent it inherits the state group from its parent.) + + There are three tables: + * `state_groups`: Stores group name, first event with in the group and + room id. + * `event_to_state_groups`: Maps events to state groups. + * `state_groups_state`: Maps state group to state events. + """ - @defer.inlineCallbacks def get_state_groups(self, event_ids): - groups = set() - for event_id in event_ids: - group = yield self._simple_select_one_onecol( - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - retcol="state_group", - allow_none=True, - ) - if group: - groups.add(group) + """ Get the state groups for the given list of event_ids - res = {} - for group in groups: - state_ids = yield self._simple_select_onecol( - table="state_groups_state", - keyvalues={"state_group": group}, - retcol="event_id", - ) - state = [] - for state_id in state_ids: - s = yield self.get_event( - state_id, + The return value is a dict mapping group names to lists of events. + """ + + def f(txn): + groups = set() + for event_id in event_ids: + group = self._simple_select_one_onecol_txn( + txn, + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + retcol="state_group", allow_none=True, ) - if s: - state.append(s) + if group: + groups.add(group) - res[group] = state + res = {} + for group in groups: + state_ids = self._simple_select_onecol_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": group}, + retcol="event_id", + ) + state = [] + for state_id in state_ids: + s = self._get_events_txn( + txn, + [state_id], + ) + if s: + state.extend(s) - defer.returnValue(res) + res[group] = state + + return res + + return self.runInteraction( + "get_state_groups", + f, + ) def store_state_groups(self, event): return self.runInteraction(