diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index e895b1c450..ec32008d5a 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -15,9 +15,25 @@ class EventContext(object): + __slots__ = [ + "current_state_ids", + "prev_state_ids", + "state_group", + "rejected", + "push_actions", + "prev_group", + "delta_ids", + "prev_state_events", + ] + def __init__(self): self.current_state_ids = None self.prev_state_ids = None self.state_group = None self.rejected = False self.push_actions = [] + + self.prev_group = None + self.delta_ids = None + + self.prev_state_events = None diff --git a/synapse/state.py b/synapse/state.py index b31bbcdbd2..cd428e83cd 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -54,12 +54,15 @@ def _gen_state_id(): class _StateCacheEntry(object): - __slots__ = ["state", "state_group", "state_id"] + __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] - def __init__(self, state, state_group): + def __init__(self, state, state_group, prev_group=None, delta_ids=None): self.state = state self.state_group = state_group + self.prev_group = prev_group + self.delta_ids = delta_ids + # The `state_id` is a unique ID we generate that can be used as ID for # this collection of state. Usually this would be the same as the # state group, but on worker instances we can't generate a new state @@ -243,11 +246,20 @@ class StateHandler(object): if key in context.prev_state_ids: replaces = context.prev_state_ids[key] event.unsigned["replaces_state"] = replaces + context.current_state_ids = dict(context.prev_state_ids) context.current_state_ids[key] = event.event_id + + context.prev_group = entry.prev_group + context.delta_ids = entry.delta_ids + if context.delta_ids is not None: + context.delta_ids[key] = event.event_id else: context.current_state_ids = context.prev_state_ids + context.prev_group = entry.prev_group + context.delta_ids = entry.delta_ids + context.prev_state_events = [] defer.returnValue(context) @@ -281,6 +293,8 @@ class StateHandler(object): defer.returnValue(_StateCacheEntry( state=state_list, state_group=name, + prev_group=name, + delta_ids={}, )) if self._state_cache is not None: @@ -330,6 +344,7 @@ class StateHandler(object): if new_state_event_ids == frozenset(e_id for e_id in events): state_group = sg break + if state_group is None: # Worker instances don't have access to this method, but we want # to set the state_group on the main instance to increase cache @@ -337,9 +352,24 @@ class StateHandler(object): if hasattr(self.store, "get_next_state_group"): state_group = self.store.get_next_state_group() + prev_group = None + delta_ids = None + for old_group, old_ids in state_groups_ids.items(): + if not set(new_state.iterkeys()) - set(old_ids.iterkeys()): + n_delta_ids = { + k: v + for k, v in new_state.items() + if old_ids.get(k) != v + } + if not delta_ids or len(n_delta_ids) < len(delta_ids): + prev_group = old_group + delta_ids = n_delta_ids + cache = _StateCacheEntry( state=new_state, state_group=state_group, + prev_group=prev_group, + delta_ids=delta_ids, ) if self._state_cache is not None: diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index b94ce7bea1..b1fbc4ffa5 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 34 +SCHEMA_VERSION = 35 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/schema/delta/35/state.sql b/synapse/storage/schema/delta/35/state.sql new file mode 100644 index 0000000000..c4c244c169 --- /dev/null +++ b/synapse/storage/schema/delta/35/state.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE state_group_edges( + state_group BIGINT NOT NULL, + prev_state_group BIGINT NOT NULL +); + +CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index ec551b0b4f..73cebc7383 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -16,6 +16,7 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches import intern_string +from synapse.storage.engines import PostgresEngine from twisted.internet import defer @@ -118,20 +119,45 @@ class StateStore(SQLBaseStore): }, ) - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { + if context.prev_group: + self._simple_insert_txn( + txn, + table="state_group_edges", + values={ "state_group": context.state_group, - "room_id": event.room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in state_event_ids.items() - ], - ) + "prev_state_group": context.prev_group, + }, + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": context.state_group, + "room_id": event.room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in context.delta_ids.items() + ], + ) + else: + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": context.state_group, + "room_id": event.room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in state_event_ids.items() + ], + ) self._simple_insert_many_txn( txn, @@ -214,26 +240,70 @@ class StateStore(SQLBaseStore): else: where_clause = "" - sql = ( - "SELECT state_group, event_id, type, state_key" - " FROM state_groups_state WHERE" - " state_group IN (%s) %s" % ( - ",".join("?" for _ in groups), - where_clause, - ) - ) - - args = list(groups) - if types is not None: - args.extend([i for typ in types for i in typ]) - - txn.execute(sql, args) - rows = self.cursor_to_dict(txn) - results = {group: {} for group in groups} - for row in rows: - key = (row["type"], row["state_key"]) - results[row["state_group"]][key] = row["event_id"] + if isinstance(self.database_engine, PostgresEngine): + sql = (""" + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT type, state_key, event_id FROM state_groups_state + WHERE ROW(type, state_key, state_group) IN ( + SELECT type, state_key, max(state_group) FROM state + INNER JOIN state_groups_state USING (state_group) + GROUP BY type, state_key + ) + %s; + """) % (where_clause,) + + for group in groups: + args = [group] + if types is not None: + args.extend([i for typ in types for i in typ]) + + txn.execute(sql, args) + rows = self.cursor_to_dict(txn) + for row in rows: + key = (row["type"], row["state_key"]) + results[group][key] = row["event_id"] + else: + for group in groups: + group_tree = [group] + next_group = group + + while next_group: + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + group_tree.append(next_group) + + sql = (""" + SELECT type, state_key, event_id FROM state_groups_state + INNER JOIN ( + SELECT type, state_key, max(state_group) as state_group + FROM state_groups_state + WHERE state_group IN (%s) %s + GROUP BY type, state_key + ) USING (type, state_key, state_group); + """) % (",".join("?" for _ in group_tree), where_clause,) + + args = list(group_tree) + if types is not None: + args.extend([i for typ in types for i in typ]) + + txn.execute(sql, args) + rows = self.cursor_to_dict(txn) + for row in rows: + key = (row["type"], row["state_key"]) + results[group][key] = row["event_id"] + return results results = {} @@ -504,32 +574,5 @@ class StateStore(SQLBaseStore): defer.returnValue(results) - def get_all_new_state_groups(self, last_id, current_id, limit): - def get_all_new_state_groups_txn(txn): - sql = ( - "SELECT id, room_id, event_id FROM state_groups" - " WHERE ? < id AND id <= ? ORDER BY id LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - groups = txn.fetchall() - - if not groups: - return ([], []) - - lower_bound = groups[0][0] - upper_bound = groups[-1][0] - sql = ( - "SELECT state_group, type, state_key, event_id" - " FROM state_groups_state" - " WHERE ? <= state_group AND state_group <= ?" - ) - - txn.execute(sql, (lower_bound, upper_bound)) - state_group_state = txn.fetchall() - return (groups, state_group_state) - return self.runInteraction( - "get_all_new_state_groups", get_all_new_state_groups_txn - ) - def get_next_state_group(self): return self._state_groups_id_gen.get_next()