0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-15 14:32:30 +01:00

Faster cache for get_joined_hosts

This commit is contained in:
Erik Johnston 2017-05-25 17:08:41 +01:00
parent 2b03751c3c
commit dfbda5e025
5 changed files with 117 additions and 29 deletions

View file

@ -187,6 +187,8 @@ class TransactionQueue(object):
prev_id for prev_id, _ in event.prev_events
],
)
destinations = set(destinations)
logger.info("destinations: %r", destinations)
if send_on_behalf_of is not None:
# If we are sending the event on behalf of another server

View file

@ -108,6 +108,8 @@ class SlavedEventStore(BaseSlavedStore):
get_current_state_ids = (
StateStore.__dict__["get_current_state_ids"]
)
get_state_group_delta = DataStore.get_state_group_delta.__func__
_get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
has_room_changed_since = DataStore.has_room_changed_since.__func__
get_unread_push_actions_for_user_in_range_for_http = (

View file

@ -170,9 +170,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_user_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(
room_id, entry.state_id, entry.state
)
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
defer.returnValue(joined_users)
@defer.inlineCallbacks
@ -181,9 +179,9 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_hosts = yield self.store.get_joined_hosts(
room_id, entry.state_id, entry.state
)
logger.info("State: %r", entry.state_group)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
logger.info("returning: %r", joined_hosts)
defer.returnValue(joined_hosts)
@defer.inlineCallbacks
@ -307,11 +305,13 @@ class StateHandler(object):
if len(group_names) == 1:
name, state_list = state_groups_ids.items().pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
defer.returnValue(_StateCacheEntry(
state=state_list,
state_group=name,
prev_group=None,
delta_ids=None,
prev_group=prev_group,
delta_ids=delta_ids,
))
with (yield self.resolve_linearizer.queue(group_names)):

View file

@ -18,6 +18,7 @@ from twisted.internet import defer
from collections import namedtuple
from ._base import SQLBaseStore
from synapse.util.async import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.stringutils import to_ascii
@ -392,7 +393,8 @@ class RoomMemberStore(SQLBaseStore):
context=context,
)
def get_joined_users_from_state(self, room_id, state_group, state_ids):
def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@ -401,7 +403,7 @@ class RoomMemberStore(SQLBaseStore):
state_group = object()
return self._get_joined_users_from_context(
room_id, state_group, state_ids,
room_id, state_group, state_entry.state, context=state_entry,
)
@cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
@ -534,7 +536,8 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(False)
def get_joined_hosts(self, room_id, state_group, state_ids):
def get_joined_hosts(self, room_id, state_entry):
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@ -543,33 +546,21 @@ class RoomMemberStore(SQLBaseStore):
state_group = object()
return self._get_joined_hosts(
room_id, state_group, state_ids
room_id, state_group, state_entry.state, state_entry=state_entry,
)
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
def _get_joined_hosts(self, room_id, state_group, current_state_ids):
# @defer.inlineCallbacks
def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
joined_hosts = set()
for etype, state_key in current_state_ids:
if etype == EventTypes.Member:
try:
host = get_domain_from_id(state_key)
except:
logger.warn("state_key not user_id: %s", state_key)
continue
if host in joined_hosts:
continue
event_id = current_state_ids[(etype, state_key)]
event = yield self.get_event(event_id, allow_none=True)
if event and event.content["membership"] == Membership.JOIN:
joined_hosts.add(intern_string(host))
cache = self._get_joined_hosts_cache(room_id)
joined_hosts = yield cache.get_destinations(state_entry)
logger.info("returning: %r", joined_hosts)
defer.returnValue(joined_hosts)
@ -647,3 +638,63 @@ class RoomMemberStore(SQLBaseStore):
yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
defer.returnValue(result)
@cached(max_entries=10000, iterable=True)
def _get_joined_hosts_cache(self, room_id):
return _JoinedHostsCache(self, room_id)
class _JoinedHostsCache(object):
def __init__(self, store, room_id):
self.store = store
self.room_id = room_id
self.hosts_to_joined_users = {}
self.state_group = object()
self.linearizer = Linearizer("_JoinedHostsCache")
self._len = 0
@defer.inlineCallbacks
def get_destinations(self, state_entry):
if state_entry.state_group == self.state_group:
defer.returnValue(frozenset(self.hosts_to_joined_users))
with (yield self.linearizer.queue(())):
if state_entry.state_group == self.state_group:
pass
elif state_entry.prev_group == self.state_group:
for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
if typ != EventTypes.Member:
continue
host = intern_string(get_domain_from_id(state_key))
user_id = state_key
known_joins = self.hosts_to_joined_users.setdefault(host, set())
event = yield self.store.get_event(event_id)
if event.membership == Membership.JOIN:
known_joins.add(user_id)
else:
known_joins.discard(user_id)
if not known_joins:
self.hosts_to_joined_users.pop(host, None)
else:
joined_users = yield self.store.get_joined_users_from_state(
self.room_id, state_entry,
)
self.hosts_to_joined_users = {}
for user_id in joined_users:
host = intern_string(get_domain_from_id(user_id))
self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
self.state_group = state_entry.state_group
self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
defer.returnValue(frozenset(self.hosts_to_joined_users))
def __len__(self):
return self._len

View file

@ -98,6 +98,39 @@ class StateStore(SQLBaseStore):
_get_current_state_ids_txn,
)
def get_state_group_delta(self, state_group):
def _get_state_group_delta_txn(txn):
prev_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={
"state_group": state_group,
},
retcol="prev_state_group",
allow_none=True,
)
if not prev_group:
return None, None
delta_ids = self._simple_select_list_txn(
txn,
table="state_groups_state",
keyvalues={
"state_group": state_group,
},
retcols=("type", "state_key", "event_id",)
)
return prev_group, {
(row["type"], row["state_key"]): row["event_id"]
for row in delta_ids
}
return self.runInteraction(
"get_state_group_delta",
_get_state_group_delta_txn,
)
@defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids):
if not event_ids: