Use stream.current_token() and remove stream_positions() ()

We move the processing of typing and federation replication traffic into their handlers so that `Stream.current_token()` points to a valid token. This allows us to remove `get_streams_to_replicate()` and `stream_positions()`.
This commit is contained in:
Erik Johnston 2020-05-01 15:21:35 +01:00 committed by GitHub
parent 6b22921b19
commit 3085cde577
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 30 additions and 154 deletions

1
changelog.d/7172.misc Normal file
View file

@ -0,0 +1 @@
Use `stream.current_token()` and remove `stream_positions()`.

View file

@ -413,12 +413,6 @@ class GenericWorkerTyping(object):
# map room IDs to sets of users currently typing # map room IDs to sets of users currently typing
self._room_typing = {} self._room_typing = {}
def stream_positions(self):
# We must update this typing token from the response of the previous
# sync. In particular, the stream id may "reset" back to zero/a low
# value which we *must* use for the next replication request.
return {"typing": self._latest_room_serial}
def process_replication_rows(self, token, rows): def process_replication_rows(self, token, rows):
if self._latest_room_serial > token: if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just # The master has gone backwards. To prevent inconsistent data, just
@ -658,13 +652,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
) )
await self.process_and_notify(stream_name, token, rows) await self.process_and_notify(stream_name, token, rows)
def get_streams_to_replicate(self):
args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
args.update(self.typing_handler.stream_positions())
if self.send_handler:
args.update(self.send_handler.stream_positions())
return args
async def process_and_notify(self, stream_name, token, rows): async def process_and_notify(self, stream_name, token, rows):
try: try:
if self.send_handler: if self.send_handler:
@ -799,9 +786,6 @@ class FederationSenderHandler(object):
def wake_destination(self, server: str): def wake_destination(self, server: str):
self.federation_sender.wake_destination(server) self.federation_sender.wake_destination(server)
def stream_positions(self):
return {"federation": self.federation_position}
async def process_replication_rows(self, stream_name, token, rows): async def process_replication_rows(self, stream_name, token, rows):
# The federation stream contains things that we want to send out, e.g. # The federation stream contains things that we want to send out, e.g.
# presence, typing, etc. # presence, typing, etc.

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, Optional from typing import Optional
import six import six
@ -49,19 +49,6 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
self.hs = hs self.hs = hs
def stream_positions(self) -> Dict[str, int]:
"""
Get the current positions of all the streams this store wants to subscribe to
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
def get_cache_stream_token(self): def get_cache_stream_token(self):
if self._cache_id_gen: if self._cache_id_gen:
return self._cache_id_gen.get_current_token() return self._cache_id_gen.get_current_token()

View file

@ -32,14 +32,6 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
def get_max_account_data_stream_id(self): def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token() return self._account_data_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedAccountDataStore, self).stream_positions()
position = self._account_data_id_gen.get_current_token()
result["user_account_data"] = position
result["room_account_data"] = position
result["tag_account_data"] = position
return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "tag_account_data": if stream_name == "tag_account_data":
self._account_data_id_gen.advance(token) self._account_data_id_gen.advance(token)

View file

@ -43,11 +43,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
expiry_ms=30 * 60 * 1000, expiry_ms=30 * 60 * 1000,
) )
def stream_positions(self):
result = super(SlavedDeviceInboxStore, self).stream_positions()
result["to_device"] = self._device_inbox_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "to_device": if stream_name == "to_device":
self._device_inbox_id_gen.advance(token) self._device_inbox_id_gen.advance(token)

View file

@ -48,16 +48,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max "DeviceListFederationStreamChangeCache", device_list_max
) )
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
# The user signature stream uses the same stream ID generator as the
# device list stream, so set them both to the device list ID
# generator's current token.
current_token = self._device_list_id_gen.get_current_token()
result[DeviceListsStream.NAME] = current_token
result[UserSignatureStream.NAME] = current_token
return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == DeviceListsStream.NAME: if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token) self._device_list_id_gen.advance(token)

View file

@ -93,12 +93,6 @@ class SlavedEventStore(
def get_room_min_stream_ordering(self): def get_room_min_stream_ordering(self):
return self._backfill_id_gen.get_current_token() return self._backfill_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
result["backfill"] = -self._backfill_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "events": if stream_name == "events":
self._stream_id_gen.advance(token) self._stream_id_gen.advance(token)

View file

@ -37,11 +37,6 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def get_group_stream_token(self): def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token() return self._group_updates_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedGroupServerStore, self).stream_positions()
result["groups"] = self._group_updates_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "groups": if stream_name == "groups":
self._group_updates_id_gen.advance(token) self._group_updates_id_gen.advance(token)

View file

@ -41,15 +41,6 @@ class SlavedPresenceStore(BaseSlavedStore):
def get_current_presence_token(self): def get_current_presence_token(self):
return self._presence_id_gen.get_current_token() return self._presence_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedPresenceStore, self).stream_positions()
if self.hs.config.use_presence:
position = self._presence_id_gen.get_current_token()
result["presence"] = position
return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "presence": if stream_name == "presence":
self._presence_id_gen.advance(token) self._presence_id_gen.advance(token)

View file

@ -37,11 +37,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self): def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token() return self._push_rules_stream_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "push_rules": if stream_name == "push_rules":
self._push_rules_stream_id_gen.advance(token) self._push_rules_stream_id_gen.advance(token)

View file

@ -28,11 +28,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
) )
def stream_positions(self):
result = super(SlavedPusherStore, self).stream_positions()
result["pushers"] = self._pushers_id_gen.get_current_token()
return result
def get_pushers_stream_token(self): def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token() return self._pushers_id_gen.get_current_token()

View file

@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def get_max_receipt_stream_id(self): def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token() return self._receipts_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
result["receipts"] = self._receipts_id_gen.get_current_token()
return result
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type)) self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate_many((room_id,)) self._get_linearized_receipts_for_room.invalidate_many((room_id,))

View file

@ -30,11 +30,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
def get_current_public_room_stream_id(self): def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token() return self._public_room_id_gen.get_current_token()
def stream_positions(self):
result = super(RoomStore, self).stream_positions()
result["public_rooms"] = self._public_room_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "public_rooms": if stream_name == "public_rooms":
self._public_room_id_gen.advance(token) self._public_room_id_gen.advance(token)

View file

@ -16,7 +16,7 @@
""" """
import logging import logging
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING
from twisted.internet.protocol import ReconnectingClientFactory from twisted.internet.protocol import ReconnectingClientFactory
@ -100,23 +100,6 @@ class ReplicationDataHandler:
""" """
self.store.process_replication_rows(stream_name, token, rows) self.store.process_replication_rows(stream_name, token, rows)
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
room_account_data = args.pop("room_account_data", None)
if user_account_data:
args["account_data"] = user_account_data
elif room_account_data:
args["account_data"] = room_account_data
return args
async def on_position(self, stream_name: str, token: int): async def on_position(self, stream_name: str, token: int):
self.store.process_replication_rows(stream_name, token, []) self.store.process_replication_rows(stream_name, token, [])

View file

@ -314,15 +314,7 @@ class ReplicationCommandHandler:
self._pending_batches.pop(cmd.stream_name, []) self._pending_batches.pop(cmd.stream_name, [])
# Find where we previously streamed up to. # Find where we previously streamed up to.
current_token = self._replication_data_handler.get_streams_to_replicate().get( current_token = stream.current_token()
cmd.stream_name
)
if current_token is None:
logger.warning(
"Got POSITION for stream we're not subscribed to: %s",
cmd.stream_name,
)
return
# If the position token matches our current token then we're up to # If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates # date and there's nothing to do. Otherwise, fetch all updates

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, List, Optional, Tuple from typing import Any, List, Optional, Tuple
import attr import attr
@ -22,13 +22,15 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.task import LoopingCall from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel from twisted.web.http import HTTPChannel
from synapse.app.generic_worker import GenericWorkerServer from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
GenericWorkerServer,
)
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -77,7 +79,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = None self._server_transport = None
def _build_replication_data_handler(self): def _build_replication_data_handler(self):
return TestReplicationDataHandler(self.worker_hs.get_datastore()) return TestReplicationDataHandler(self.worker_hs)
def reconnect(self): def reconnect(self):
if self._client_transport: if self._client_transport:
@ -172,32 +174,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(request.method, b"GET") self.assertEqual(request.method, b"GET")
class TestReplicationDataHandler(ReplicationDataHandler): class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows""" """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
def __init__(self, store: BaseSlavedStore): def __init__(self, hs: HomeServer):
super().__init__(store) super().__init__(hs)
# streams to subscribe to: map from stream id to position
self.stream_positions = {} # type: Dict[str, int]
# list of received (stream_name, token, row) tuples # list of received (stream_name, token, row) tuples
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
def get_streams_to_replicate(self):
return self.stream_positions
async def on_rdata(self, stream_name, token, rows): async def on_rdata(self, stream_name, token, rows):
await super().on_rdata(stream_name, token, rows) await super().on_rdata(stream_name, token, rows)
for r in rows: for r in rows:
self.received_rdata_rows.append((stream_name, token, r)) self.received_rdata_rows.append((stream_name, token, r))
if (
stream_name in self.stream_positions
and token > self.stream_positions[stream_name]
):
self.stream_positions[stream_name] = token
@attr.s() @attr.s()
class OneShotRequestFactory: class OneShotRequestFactory:

View file

@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.user_tok = self.login("u1", "pass") self.user_tok = self.login("u1", "pass")
self.reconnect() self.reconnect()
self.test_handler.stream_positions["events"] = 0
self.room_id = self.helper.create_room_as(tok=self.user_tok) self.room_id = self.helper.create_room_as(tok=self.user_tok)
self.test_handler.received_rdata_rows.clear() self.test_handler.received_rdata_rows.clear()
@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.reconnect() self.reconnect()
self.replicate() self.replicate()
# we should have received all the expected rows in the right order # we should have received all the expected rows in the right order (as
received_rows = self.test_handler.received_rdata_rows # well as various cache invalidation updates which we ignore)
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
for event in events: for event in events:
stream_name, token, row = received_rows.pop(0) stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name) self.assertEqual("events", stream_name)
@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.reconnect() self.reconnect()
self.replicate() self.replicate()
# now we should have received all the expected rows in the right order. # we should have received all the expected rows in the right order (as
# well as various cache invalidation updates which we ignore)
# #
# we expect: # we expect:
# #
@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
# of the states that got reverted. # of the states that got reverted.
# - two rows for state2 # - two rows for state2
received_rows = self.test_handler.received_rdata_rows received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
# first check the first two rows, which should be state1 # first check the first two rows, which should be state1
@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.reconnect() self.reconnect()
self.replicate() self.replicate()
# we should have received all the expected rows in the right order # we should have received all the expected rows in the right order (as
# well as various cache invalidation updates which we ignore)
received_rows = self.test_handler.received_rdata_rows received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
self.assertGreaterEqual(len(received_rows), len(events)) self.assertGreaterEqual(len(received_rows), len(events))
for i in range(NUM_USERS): for i in range(NUM_USERS):
# for each user, we expect the PL event row, followed by state rows for # for each user, we expect the PL event row, followed by state rows for

View file

@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
def test_receipt(self): def test_receipt(self):
self.reconnect() self.reconnect()
# make the client subscribe to the receipts stream
self.test_handler.stream_positions.update({"receipts": 0})
# tell the master to send a new receipt # tell the master to send a new receipt
self.get_success( self.get_success(
self.hs.get_datastore().insert_receipt( self.hs.get_datastore().insert_receipt(

View file

@ -38,9 +38,6 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.reconnect() self.reconnect()
# make the client subscribe to the typing stream
self.test_handler.stream_positions.update({"typing": 0})
typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
self.reactor.advance(0) self.reactor.advance(0)