Linearize state resolution to help caches

This commit is contained in:
Erik Johnston 2016-09-01 14:55:03 +01:00
parent 265d847ffd
commit 051a9ea921

View file

@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer
from collections import namedtuple from collections import namedtuple
@ -84,6 +85,7 @@ class StateHandler(object):
# dict of set of event_ids -> _StateCacheEntry. # dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None self._state_cache = None
self.resolve_linearizer = Linearizer()
def start_caching(self): def start_caching(self):
logger.debug("start_caching") logger.debug("start_caching")
@ -283,69 +285,70 @@ class StateHandler(object):
state_group=name, state_group=name,
)) ))
if self._state_cache is not None: with (yield self.resolve_linearizer.queue(group_names)):
cache = self._state_cache.get(group_names, None) if self._state_cache is not None:
if cache: cache = self._state_cache.get(group_names, None)
defer.returnValue(cache) if cache:
defer.returnValue(cache)
logger.info( logger.info(
"Resolving state for %s with %d groups", room_id, len(state_groups_ids) "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
)
state = {}
for st in state_groups_ids.values():
for key, e_id in st.items():
state.setdefault(key, set()).add(e_id)
conflicted_state = {
k: list(v)
for k, v in state.items()
if len(v) > 1
}
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
state_map = yield self.store.get_events(
[e_id for st in state_groups_ids.values() for e_id in st.values()],
get_prev_content=False
) )
state_sets = [
[state_map[e_id] for key, e_id in st.items() if e_id in state_map] state = {}
for st in state_groups_ids.values() for st in state_groups_ids.values():
] for key, e_id in st.items():
new_state, _ = self._resolve_events( state.setdefault(key, set()).add(e_id)
state_sets, event_type, state_key
) conflicted_state = {
new_state = { k: list(v)
key: e.event_id for key, e in new_state.items() for k, v in state.items()
} if len(v) > 1
else:
new_state = {
key: e_ids.pop() for key, e_ids in state.items()
} }
state_group = None if conflicted_state:
new_state_event_ids = frozenset(new_state.values()) logger.info("Resolving conflicted state for %r", room_id)
for sg, events in state_groups_ids.items(): state_map = yield self.store.get_events(
if new_state_event_ids == frozenset(e_id for e_id in events): [e_id for st in state_groups_ids.values() for e_id in st.values()],
state_group = sg get_prev_content=False
break )
if state_group is None: state_sets = [
# Worker instances don't have access to this method, but we want [state_map[e_id] for key, e_id in st.items() if e_id in state_map]
# to set the state_group on the main instance to increase cache for st in state_groups_ids.values()
# hits. ]
if hasattr(self.store, "get_next_state_group"): new_state, _ = self._resolve_events(
state_group = self.store.get_next_state_group() state_sets, event_type, state_key
)
new_state = {
key: e.event_id for key, e in new_state.items()
}
else:
new_state = {
key: e_ids.pop() for key, e_ids in state.items()
}
cache = _StateCacheEntry( state_group = None
state=new_state, new_state_event_ids = frozenset(new_state.values())
state_group=state_group, for sg, events in state_groups_ids.items():
) if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg
break
if state_group is None:
# Worker instances don't have access to this method, but we want
# to set the state_group on the main instance to increase cache
# hits.
if hasattr(self.store, "get_next_state_group"):
state_group = self.store.get_next_state_group()
if self._state_cache is not None: cache = _StateCacheEntry(
self._state_cache[group_names] = cache state=new_state,
state_group=state_group,
)
defer.returnValue(cache) if self._state_cache is not None:
self._state_cache[group_names] = cache
defer.returnValue(cache)
def resolve_events(self, state_sets, event): def resolve_events(self, state_sets, event):
logger.info( logger.info(