0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-02 18:59:04 +02:00

Split PushRulesStore

This commit is contained in:
Erik Johnston 2018-02-21 10:39:27 +00:00
parent a2b25de68d
commit cbaad969f9
3 changed files with 61 additions and 45 deletions

View file

@ -15,29 +15,15 @@
from .events import SlavedEventStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.push_rule import PushRuleStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.storage.push_rule import PushRulesWorkerStore
class SlavedPushRuleStore(SlavedEventStore):
class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
def __init__(self, db_conn, hs):
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id",
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache",
self._push_rules_stream_id_gen.get_current_token(),
)
get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
get_push_rules_enabled_for_user = (
PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
)
have_push_rules_changed_for_user = (
DataStore.have_push_rules_changed_for_user.__func__
)
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
def get_push_rules_stream_token(self):
return (
@ -45,6 +31,9 @@ class SlavedPushRuleStore(SlavedEventStore):
self._stream_id_gen.get_current_token(),
)
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()

View file

@ -177,18 +177,6 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=presence_cache_prefill
)
push_rules_prefill, push_rules_id = self._get_cache_dict(
db_conn, "push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._push_rules_stream_id_gen.get_current_token()[0],
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache", push_rules_id,
prefilled_cache=push_rules_prefill,
)
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
db_conn, "device_inbox",

View file

@ -15,10 +15,12 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.push.baserules import list_with_base_rules
from synapse.api.constants import EventTypes
from twisted.internet import defer
import abc
import logging
import simplejson as json
@ -48,7 +50,39 @@ def _load_rules(rawrules, enabled_map):
return rules
class PushRuleStore(SQLBaseStore):
class PushRulesWorkerStore(SQLBaseStore):
"""This is an abstract base class where subclasses must implement
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
# This ABCMeta metaclass ensures that we cannot be instantiated without
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
def __init__(self, db_conn, hs):
super(PushRulesWorkerStore, self).__init__(db_conn, hs)
push_rules_prefill, push_rules_id = self._get_cache_dict(
db_conn, "push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self.get_max_push_rules_stream_id(),
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache", push_rules_id,
prefilled_cache=push_rules_prefill,
)
@abc.abstractmethod
def get_max_push_rules_stream_id(self):
"""Get the position of the push rules stream.
Returns:
int
"""
raise NotImplementedError()
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
@ -89,6 +123,24 @@ class PushRuleStore(SQLBaseStore):
r['rule_id']: False if r['enabled'] == 0 else True for r in results
})
def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
return defer.succeed(False)
else:
def have_push_rules_changed_txn(txn):
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
count, = txn.fetchone()
return bool(count)
return self.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
class PushRuleStore(PushRulesWorkerStore):
@cachedList(cached_method_name="get_push_rules_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules(self, user_ids):
@ -526,21 +578,8 @@ class PushRuleStore(SQLBaseStore):
room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_current_token()
def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
return defer.succeed(False)
else:
def have_push_rules_changed_txn(txn):
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
count, = txn.fetchone()
return bool(count)
return self.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
def get_max_push_rules_stream_id(self):
return self.get_push_rules_stream_token()[0]
class RuleNotFoundException(Exception):