0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-19 05:21:55 +01:00

Add a slaved receipts store

This commit is contained in:
Mark Haines 2016-04-19 17:11:44 +01:00
parent e8884e5e9c
commit 5bbd424ee0
4 changed files with 104 additions and 3 deletions

View file

@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.receipts import ReceiptsStore
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
# DataStore or are cached and don't have cache invalidation logic.
#
# Rather than write duplicate versions of those functions, or lift them to
# a common base class, we going to grab the underlying __func__ object from
# the method descriptor on the DataStore and chuck them into our class.
class SlavedReceiptsStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedReceiptsStore, self).__init__(db_conn, hs)
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id"
)
get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
result["receipts"] = self._receipts_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("receipts")
if stream:
self._receipts_id_gen.advance(stream["position"])
for row in stream["rows"]:
room_id, receipt_type, user_id = row[1:4]
self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
return super(SlavedReceiptsStore, self).process_replication(result)
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))

View file

@ -15,8 +15,6 @@
from twisted.internet import defer from twisted.internet import defer
from tests import unittest from tests import unittest
from synapse.replication.slave.storage.events import SlavedEventStore
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
from synapse.replication.resource import ReplicationResource from synapse.replication.resource import ReplicationResource
@ -38,7 +36,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
self.replication = ReplicationResource(self.hs) self.replication = ReplicationResource(self.hs)
self.master_store = self.hs.get_datastore() self.master_store = self.hs.get_datastore()
self.slaved_store = SlavedEventStore(self.hs.get_db_conn(), self.hs) self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
self.event_id = 0 self.event_id = 0
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -16,6 +16,7 @@ from ._base import BaseSlavedStoreTestCase
from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events import FrozenEvent, _EventInternalMetadata
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser from synapse.storage.roommember import RoomsForUser
from twisted.internet import defer from twisted.internet import defer
@ -43,6 +44,8 @@ def patch__eq__(cls):
class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = SlavedEventStore
def setUp(self): def setUp(self):
# Patch up the equality operator for events so that we can check # Patch up the equality operator for events so that we can check
# whether lists of events match using assertEquals # whether lists of events match using assertEquals

View file

@ -0,0 +1,39 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStoreTestCase
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from twisted.internet import defer
USER_ID = "@feeling:blue"
ROOM_ID = "!room:blue"
EVENT_ID = "$event:blue"
class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = SlavedReceiptsStore
@defer.inlineCallbacks
def test_receipt(self):
yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
yield self.master_store.insert_receipt(
ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}
)
yield self.replicate()
yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {
ROOM_ID: EVENT_ID
})