Split AccountDataStore and TagStore

This commit is contained in:
Erik Johnston 2018-02-16 12:08:42 +00:00
parent a2b25de68d
commit ca9b9d9703
4 changed files with 69 additions and 67 deletions

View file

@ -15,48 +15,18 @@
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.account_data import AccountDataStore
from synapse.storage.tags import TagsStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.storage.account_data import AccountDataWorkerStore
from synapse.storage.tags import TagsWorkerStore
class SlavedAccountDataStore(BaseSlavedStore):
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedAccountDataStore, self).__init__(db_conn, hs)
self._account_data_id_gen = SlavedIdTracker(
db_conn, "account_data_max_stream_id", "stream_id",
)
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache",
self._account_data_id_gen.get_current_token(),
)
get_account_data_for_user = (
AccountDataStore.__dict__["get_account_data_for_user"]
)
get_global_account_data_by_type_for_users = (
AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
)
get_global_account_data_by_type_for_user = (
AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
)
get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
get_tags_for_room = (
DataStore.get_tags_for_room.__func__
)
get_account_data_for_room = (
DataStore.get_account_data_for_room.__func__
)
get_updated_tags = DataStore.get_updated_tags.__func__
get_updated_account_data_for_user = (
DataStore.get_updated_account_data_for_user.__func__
)
super(SlavedAccountDataStore, self).__init__(db_conn, hs)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()

View file

@ -104,9 +104,6 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "events", "stream_ordering", step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
)
self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id"
)
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
)
@ -159,11 +156,6 @@ class DataStore(RoomMemberStore, RoomStore,
"MembershipStreamChangeCache", events_max,
)
account_max = self._account_data_id_gen.get_current_token()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max,
)
self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(

View file

@ -13,18 +13,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from twisted.internet import defer
from ._base import SQLBaseStore
from .util.id_generators import StreamIdGenerator
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
import abc
import ujson as json
import logging
logger = logging.getLogger(__name__)
class AccountDataStore(SQLBaseStore):
class AccountDataWorkerStore(SQLBaseStore):
"""This is an abstract base class where subclasses must implement
`get_max_account_data_stream_id` which can be called in the initializer.
"""
# This ABCMeta metaclass ensures that we cannot be instantiated without
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
def __init__(self, db_conn, hs):
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max,
)
super(AccountDataWorkerStore, self).__init__(db_conn, hs)
@abc.abstractmethod
def get_max_account_data_stream_id(self):
"""Get the current max stream ID for account data stream
Returns:
int
"""
raise NotImplementedError()
@cached()
def get_account_data_for_user(self, user_id):
@ -209,6 +237,36 @@ class AccountDataStore(SQLBaseStore):
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
ignored_account_data = yield self.get_global_account_data_by_type_for_user(
"m.ignored_user_list", ignorer_user_id,
on_invalidate=cache_context.invalidate,
)
if not ignored_account_data:
defer.returnValue(False)
defer.returnValue(
ignored_user_id in ignored_account_data.get("ignored_users", {})
)
class AccountDataStore(AccountDataWorkerStore):
def __init__(self, db_conn, hs):
self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id"
)
super(AccountDataStore, self).__init__(db_conn, hs)
def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream
Returns:
A deferred int.
"""
return self._account_data_id_gen.get_current_token()
@defer.inlineCallbacks
def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
"""Add some account_data to a room for a user.
@ -321,16 +379,3 @@ class AccountDataStore(SQLBaseStore):
"update_account_data_max_stream_id",
_update,
)
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
ignored_account_data = yield self.get_global_account_data_by_type_for_user(
"m.ignored_user_list", ignorer_user_id,
on_invalidate=cache_context.invalidate,
)
if not ignored_account_data:
defer.returnValue(False)
defer.returnValue(
ignored_user_id in ignored_account_data.get("ignored_users", {})
)

View file

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from .account_data import AccountDataWorkerStore
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
@ -23,15 +24,7 @@ import logging
logger = logging.getLogger(__name__)
class TagsStore(SQLBaseStore):
def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream
Returns:
A deferred int.
"""
return self._account_data_id_gen.get_current_token()
class TagsWorkerStore(AccountDataWorkerStore):
@cached()
def get_tags_for_user(self, user_id):
"""Get all the tags for a user.
@ -170,6 +163,8 @@ class TagsStore(SQLBaseStore):
row["tag"]: json.loads(row["content"]) for row in rows
})
class TagsStore(TagsWorkerStore):
@defer.inlineCallbacks
def add_tag_to_room(self, user_id, room_id, tag, content):
"""Add a tag to a room for a user.