forked from MirrorHub/synapse
Use stream.current_token()
and remove stream_positions()
(#7172)
We move the processing of typing and federation replication traffic into their handlers so that `Stream.current_token()` points to a valid token. This allows us to remove `get_streams_to_replicate()` and `stream_positions()`.
This commit is contained in:
parent
6b22921b19
commit
3085cde577
19 changed files with 30 additions and 154 deletions
changelog.d
synapse
app
replication
tests/replication/tcp/streams
1
changelog.d/7172.misc
Normal file
1
changelog.d/7172.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Use `stream.current_token()` and remove `stream_positions()`.
|
|
@ -413,12 +413,6 @@ class GenericWorkerTyping(object):
|
||||||
# map room IDs to sets of users currently typing
|
# map room IDs to sets of users currently typing
|
||||||
self._room_typing = {}
|
self._room_typing = {}
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
# We must update this typing token from the response of the previous
|
|
||||||
# sync. In particular, the stream id may "reset" back to zero/a low
|
|
||||||
# value which we *must* use for the next replication request.
|
|
||||||
return {"typing": self._latest_room_serial}
|
|
||||||
|
|
||||||
def process_replication_rows(self, token, rows):
|
def process_replication_rows(self, token, rows):
|
||||||
if self._latest_room_serial > token:
|
if self._latest_room_serial > token:
|
||||||
# The master has gone backwards. To prevent inconsistent data, just
|
# The master has gone backwards. To prevent inconsistent data, just
|
||||||
|
@ -658,13 +652,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
|
||||||
)
|
)
|
||||||
await self.process_and_notify(stream_name, token, rows)
|
await self.process_and_notify(stream_name, token, rows)
|
||||||
|
|
||||||
def get_streams_to_replicate(self):
|
|
||||||
args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
|
|
||||||
args.update(self.typing_handler.stream_positions())
|
|
||||||
if self.send_handler:
|
|
||||||
args.update(self.send_handler.stream_positions())
|
|
||||||
return args
|
|
||||||
|
|
||||||
async def process_and_notify(self, stream_name, token, rows):
|
async def process_and_notify(self, stream_name, token, rows):
|
||||||
try:
|
try:
|
||||||
if self.send_handler:
|
if self.send_handler:
|
||||||
|
@ -799,9 +786,6 @@ class FederationSenderHandler(object):
|
||||||
def wake_destination(self, server: str):
|
def wake_destination(self, server: str):
|
||||||
self.federation_sender.wake_destination(server)
|
self.federation_sender.wake_destination(server)
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
return {"federation": self.federation_position}
|
|
||||||
|
|
||||||
async def process_replication_rows(self, stream_name, token, rows):
|
async def process_replication_rows(self, stream_name, token, rows):
|
||||||
# The federation stream contains things that we want to send out, e.g.
|
# The federation stream contains things that we want to send out, e.g.
|
||||||
# presence, typing, etc.
|
# presence, typing, etc.
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
@ -49,19 +49,6 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
|
||||||
|
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
def stream_positions(self) -> Dict[str, int]:
|
|
||||||
"""
|
|
||||||
Get the current positions of all the streams this store wants to subscribe to
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
map from stream name to the most recent update we have for
|
|
||||||
that stream (ie, the point we want to start replicating from)
|
|
||||||
"""
|
|
||||||
pos = {}
|
|
||||||
if self._cache_id_gen:
|
|
||||||
pos["caches"] = self._cache_id_gen.get_current_token()
|
|
||||||
return pos
|
|
||||||
|
|
||||||
def get_cache_stream_token(self):
|
def get_cache_stream_token(self):
|
||||||
if self._cache_id_gen:
|
if self._cache_id_gen:
|
||||||
return self._cache_id_gen.get_current_token()
|
return self._cache_id_gen.get_current_token()
|
||||||
|
|
|
@ -32,14 +32,6 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
|
||||||
def get_max_account_data_stream_id(self):
|
def get_max_account_data_stream_id(self):
|
||||||
return self._account_data_id_gen.get_current_token()
|
return self._account_data_id_gen.get_current_token()
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedAccountDataStore, self).stream_positions()
|
|
||||||
position = self._account_data_id_gen.get_current_token()
|
|
||||||
result["user_account_data"] = position
|
|
||||||
result["room_account_data"] = position
|
|
||||||
result["tag_account_data"] = position
|
|
||||||
return result
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "tag_account_data":
|
if stream_name == "tag_account_data":
|
||||||
self._account_data_id_gen.advance(token)
|
self._account_data_id_gen.advance(token)
|
||||||
|
|
|
@ -43,11 +43,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
|
||||||
expiry_ms=30 * 60 * 1000,
|
expiry_ms=30 * 60 * 1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedDeviceInboxStore, self).stream_positions()
|
|
||||||
result["to_device"] = self._device_inbox_id_gen.get_current_token()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "to_device":
|
if stream_name == "to_device":
|
||||||
self._device_inbox_id_gen.advance(token)
|
self._device_inbox_id_gen.advance(token)
|
||||||
|
|
|
@ -48,16 +48,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
||||||
"DeviceListFederationStreamChangeCache", device_list_max
|
"DeviceListFederationStreamChangeCache", device_list_max
|
||||||
)
|
)
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedDeviceStore, self).stream_positions()
|
|
||||||
# The user signature stream uses the same stream ID generator as the
|
|
||||||
# device list stream, so set them both to the device list ID
|
|
||||||
# generator's current token.
|
|
||||||
current_token = self._device_list_id_gen.get_current_token()
|
|
||||||
result[DeviceListsStream.NAME] = current_token
|
|
||||||
result[UserSignatureStream.NAME] = current_token
|
|
||||||
return result
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == DeviceListsStream.NAME:
|
if stream_name == DeviceListsStream.NAME:
|
||||||
self._device_list_id_gen.advance(token)
|
self._device_list_id_gen.advance(token)
|
||||||
|
|
|
@ -93,12 +93,6 @@ class SlavedEventStore(
|
||||||
def get_room_min_stream_ordering(self):
|
def get_room_min_stream_ordering(self):
|
||||||
return self._backfill_id_gen.get_current_token()
|
return self._backfill_id_gen.get_current_token()
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedEventStore, self).stream_positions()
|
|
||||||
result["events"] = self._stream_id_gen.get_current_token()
|
|
||||||
result["backfill"] = -self._backfill_id_gen.get_current_token()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "events":
|
if stream_name == "events":
|
||||||
self._stream_id_gen.advance(token)
|
self._stream_id_gen.advance(token)
|
||||||
|
|
|
@ -37,11 +37,6 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
|
||||||
def get_group_stream_token(self):
|
def get_group_stream_token(self):
|
||||||
return self._group_updates_id_gen.get_current_token()
|
return self._group_updates_id_gen.get_current_token()
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedGroupServerStore, self).stream_positions()
|
|
||||||
result["groups"] = self._group_updates_id_gen.get_current_token()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "groups":
|
if stream_name == "groups":
|
||||||
self._group_updates_id_gen.advance(token)
|
self._group_updates_id_gen.advance(token)
|
||||||
|
|
|
@ -41,15 +41,6 @@ class SlavedPresenceStore(BaseSlavedStore):
|
||||||
def get_current_presence_token(self):
|
def get_current_presence_token(self):
|
||||||
return self._presence_id_gen.get_current_token()
|
return self._presence_id_gen.get_current_token()
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedPresenceStore, self).stream_positions()
|
|
||||||
|
|
||||||
if self.hs.config.use_presence:
|
|
||||||
position = self._presence_id_gen.get_current_token()
|
|
||||||
result["presence"] = position
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "presence":
|
if stream_name == "presence":
|
||||||
self._presence_id_gen.advance(token)
|
self._presence_id_gen.advance(token)
|
||||||
|
|
|
@ -37,11 +37,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
|
||||||
def get_max_push_rules_stream_id(self):
|
def get_max_push_rules_stream_id(self):
|
||||||
return self._push_rules_stream_id_gen.get_current_token()
|
return self._push_rules_stream_id_gen.get_current_token()
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedPushRuleStore, self).stream_positions()
|
|
||||||
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "push_rules":
|
if stream_name == "push_rules":
|
||||||
self._push_rules_stream_id_gen.advance(token)
|
self._push_rules_stream_id_gen.advance(token)
|
||||||
|
|
|
@ -28,11 +28,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
||||||
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
||||||
)
|
)
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedPusherStore, self).stream_positions()
|
|
||||||
result["pushers"] = self._pushers_id_gen.get_current_token()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_pushers_stream_token(self):
|
def get_pushers_stream_token(self):
|
||||||
return self._pushers_id_gen.get_current_token()
|
return self._pushers_id_gen.get_current_token()
|
||||||
|
|
||||||
|
|
|
@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
|
||||||
def get_max_receipt_stream_id(self):
|
def get_max_receipt_stream_id(self):
|
||||||
return self._receipts_id_gen.get_current_token()
|
return self._receipts_id_gen.get_current_token()
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedReceiptsStore, self).stream_positions()
|
|
||||||
result["receipts"] = self._receipts_id_gen.get_current_token()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
|
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
|
||||||
self.get_receipts_for_user.invalidate((user_id, receipt_type))
|
self.get_receipts_for_user.invalidate((user_id, receipt_type))
|
||||||
self._get_linearized_receipts_for_room.invalidate_many((room_id,))
|
self._get_linearized_receipts_for_room.invalidate_many((room_id,))
|
||||||
|
|
|
@ -30,11 +30,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
|
||||||
def get_current_public_room_stream_id(self):
|
def get_current_public_room_stream_id(self):
|
||||||
return self._public_room_id_gen.get_current_token()
|
return self._public_room_id_gen.get_current_token()
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(RoomStore, self).stream_positions()
|
|
||||||
result["public_rooms"] = self._public_room_id_gen.get_current_token()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "public_rooms":
|
if stream_name == "public_rooms":
|
||||||
self._public_room_id_gen.advance(token)
|
self._public_room_id_gen.advance(token)
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from twisted.internet.protocol import ReconnectingClientFactory
|
from twisted.internet.protocol import ReconnectingClientFactory
|
||||||
|
|
||||||
|
@ -100,23 +100,6 @@ class ReplicationDataHandler:
|
||||||
"""
|
"""
|
||||||
self.store.process_replication_rows(stream_name, token, rows)
|
self.store.process_replication_rows(stream_name, token, rows)
|
||||||
|
|
||||||
def get_streams_to_replicate(self) -> Dict[str, int]:
|
|
||||||
"""Called when a new connection has been established and we need to
|
|
||||||
subscribe to streams.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
map from stream name to the most recent update we have for
|
|
||||||
that stream (ie, the point we want to start replicating from)
|
|
||||||
"""
|
|
||||||
args = self.store.stream_positions()
|
|
||||||
user_account_data = args.pop("user_account_data", None)
|
|
||||||
room_account_data = args.pop("room_account_data", None)
|
|
||||||
if user_account_data:
|
|
||||||
args["account_data"] = user_account_data
|
|
||||||
elif room_account_data:
|
|
||||||
args["account_data"] = room_account_data
|
|
||||||
return args
|
|
||||||
|
|
||||||
async def on_position(self, stream_name: str, token: int):
|
async def on_position(self, stream_name: str, token: int):
|
||||||
self.store.process_replication_rows(stream_name, token, [])
|
self.store.process_replication_rows(stream_name, token, [])
|
||||||
|
|
||||||
|
|
|
@ -314,15 +314,7 @@ class ReplicationCommandHandler:
|
||||||
self._pending_batches.pop(cmd.stream_name, [])
|
self._pending_batches.pop(cmd.stream_name, [])
|
||||||
|
|
||||||
# Find where we previously streamed up to.
|
# Find where we previously streamed up to.
|
||||||
current_token = self._replication_data_handler.get_streams_to_replicate().get(
|
current_token = stream.current_token()
|
||||||
cmd.stream_name
|
|
||||||
)
|
|
||||||
if current_token is None:
|
|
||||||
logger.warning(
|
|
||||||
"Got POSITION for stream we're not subscribed to: %s",
|
|
||||||
cmd.stream_name,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# If the position token matches our current token then we're up to
|
# If the position token matches our current token then we're up to
|
||||||
# date and there's nothing to do. Otherwise, fetch all updates
|
# date and there's nothing to do. Otherwise, fetch all updates
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -22,13 +22,15 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
|
||||||
from twisted.internet.task import LoopingCall
|
from twisted.internet.task import LoopingCall
|
||||||
from twisted.web.http import HTTPChannel
|
from twisted.web.http import HTTPChannel
|
||||||
|
|
||||||
from synapse.app.generic_worker import GenericWorkerServer
|
from synapse.app.generic_worker import (
|
||||||
|
GenericWorkerReplicationHandler,
|
||||||
|
GenericWorkerServer,
|
||||||
|
)
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
|
||||||
from synapse.replication.tcp.client import ReplicationDataHandler
|
|
||||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -77,7 +79,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
self._server_transport = None
|
self._server_transport = None
|
||||||
|
|
||||||
def _build_replication_data_handler(self):
|
def _build_replication_data_handler(self):
|
||||||
return TestReplicationDataHandler(self.worker_hs.get_datastore())
|
return TestReplicationDataHandler(self.worker_hs)
|
||||||
|
|
||||||
def reconnect(self):
|
def reconnect(self):
|
||||||
if self._client_transport:
|
if self._client_transport:
|
||||||
|
@ -172,32 +174,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(request.method, b"GET")
|
self.assertEqual(request.method, b"GET")
|
||||||
|
|
||||||
|
|
||||||
class TestReplicationDataHandler(ReplicationDataHandler):
|
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
|
||||||
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
|
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
|
||||||
|
|
||||||
def __init__(self, store: BaseSlavedStore):
|
def __init__(self, hs: HomeServer):
|
||||||
super().__init__(store)
|
super().__init__(hs)
|
||||||
|
|
||||||
# streams to subscribe to: map from stream id to position
|
|
||||||
self.stream_positions = {} # type: Dict[str, int]
|
|
||||||
|
|
||||||
# list of received (stream_name, token, row) tuples
|
# list of received (stream_name, token, row) tuples
|
||||||
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
|
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
|
||||||
|
|
||||||
def get_streams_to_replicate(self):
|
|
||||||
return self.stream_positions
|
|
||||||
|
|
||||||
async def on_rdata(self, stream_name, token, rows):
|
async def on_rdata(self, stream_name, token, rows):
|
||||||
await super().on_rdata(stream_name, token, rows)
|
await super().on_rdata(stream_name, token, rows)
|
||||||
for r in rows:
|
for r in rows:
|
||||||
self.received_rdata_rows.append((stream_name, token, r))
|
self.received_rdata_rows.append((stream_name, token, r))
|
||||||
|
|
||||||
if (
|
|
||||||
stream_name in self.stream_positions
|
|
||||||
and token > self.stream_positions[stream_name]
|
|
||||||
):
|
|
||||||
self.stream_positions[stream_name] = token
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s()
|
@attr.s()
|
||||||
class OneShotRequestFactory:
|
class OneShotRequestFactory:
|
||||||
|
|
|
@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
self.user_tok = self.login("u1", "pass")
|
self.user_tok = self.login("u1", "pass")
|
||||||
|
|
||||||
self.reconnect()
|
self.reconnect()
|
||||||
self.test_handler.stream_positions["events"] = 0
|
|
||||||
|
|
||||||
self.room_id = self.helper.create_room_as(tok=self.user_tok)
|
self.room_id = self.helper.create_room_as(tok=self.user_tok)
|
||||||
self.test_handler.received_rdata_rows.clear()
|
self.test_handler.received_rdata_rows.clear()
|
||||||
|
@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
self.reconnect()
|
self.reconnect()
|
||||||
self.replicate()
|
self.replicate()
|
||||||
|
|
||||||
# we should have received all the expected rows in the right order
|
# we should have received all the expected rows in the right order (as
|
||||||
received_rows = self.test_handler.received_rdata_rows
|
# well as various cache invalidation updates which we ignore)
|
||||||
|
received_rows = [
|
||||||
|
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
|
||||||
|
]
|
||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
stream_name, token, row = received_rows.pop(0)
|
stream_name, token, row = received_rows.pop(0)
|
||||||
self.assertEqual("events", stream_name)
|
self.assertEqual("events", stream_name)
|
||||||
|
@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
self.reconnect()
|
self.reconnect()
|
||||||
self.replicate()
|
self.replicate()
|
||||||
|
|
||||||
# now we should have received all the expected rows in the right order.
|
# we should have received all the expected rows in the right order (as
|
||||||
|
# well as various cache invalidation updates which we ignore)
|
||||||
#
|
#
|
||||||
# we expect:
|
# we expect:
|
||||||
#
|
#
|
||||||
|
@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
# of the states that got reverted.
|
# of the states that got reverted.
|
||||||
# - two rows for state2
|
# - two rows for state2
|
||||||
|
|
||||||
received_rows = self.test_handler.received_rdata_rows
|
received_rows = [
|
||||||
|
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
|
||||||
|
]
|
||||||
|
|
||||||
# first check the first two rows, which should be state1
|
# first check the first two rows, which should be state1
|
||||||
|
|
||||||
|
@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
self.reconnect()
|
self.reconnect()
|
||||||
self.replicate()
|
self.replicate()
|
||||||
|
|
||||||
# we should have received all the expected rows in the right order
|
# we should have received all the expected rows in the right order (as
|
||||||
|
# well as various cache invalidation updates which we ignore)
|
||||||
received_rows = self.test_handler.received_rdata_rows
|
received_rows = [
|
||||||
|
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
|
||||||
|
]
|
||||||
self.assertGreaterEqual(len(received_rows), len(events))
|
self.assertGreaterEqual(len(received_rows), len(events))
|
||||||
for i in range(NUM_USERS):
|
for i in range(NUM_USERS):
|
||||||
# for each user, we expect the PL event row, followed by state rows for
|
# for each user, we expect the PL event row, followed by state rows for
|
||||||
|
|
|
@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
|
||||||
def test_receipt(self):
|
def test_receipt(self):
|
||||||
self.reconnect()
|
self.reconnect()
|
||||||
|
|
||||||
# make the client subscribe to the receipts stream
|
|
||||||
self.test_handler.stream_positions.update({"receipts": 0})
|
|
||||||
|
|
||||||
# tell the master to send a new receipt
|
# tell the master to send a new receipt
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_datastore().insert_receipt(
|
self.hs.get_datastore().insert_receipt(
|
||||||
|
|
|
@ -38,9 +38,6 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
||||||
|
|
||||||
self.reconnect()
|
self.reconnect()
|
||||||
|
|
||||||
# make the client subscribe to the typing stream
|
|
||||||
self.test_handler.stream_positions.update({"typing": 0})
|
|
||||||
|
|
||||||
typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
|
typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
|
||||||
|
|
||||||
self.reactor.advance(0)
|
self.reactor.advance(0)
|
||||||
|
|
Loading…
Reference in a new issue