diff --git a/changelog.d/7853.misc b/changelog.d/7853.misc new file mode 100644 index 000000000..b4f614084 --- /dev/null +++ b/changelog.d/7853.misc @@ -0,0 +1 @@ +Add support for handling registration requests across multiple client reader workers. diff --git a/synapse/http/client.py b/synapse/http/client.py index 505872ee9..b80681135 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -31,6 +31,7 @@ from twisted.internet.interfaces import ( IReactorPluggableNameResolver, IResolutionReceiver, ) +from twisted.internet.task import Cooperator from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone from twisted.web.client import Agent, HTTPConnectionPool, readBody @@ -69,6 +70,21 @@ def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist): return False +_EPSILON = 0.00000001 + + +def _make_scheduler(reactor): + """Makes a schedular suitable for a Cooperator using the given reactor. + + (This is effectively just a copy from `twisted.internet.task`) + """ + + def _scheduler(x): + return reactor.callLater(_EPSILON, x) + + return _scheduler + + class IPBlacklistingResolver(object): """ A proxy for reactor.nameResolver which only produces non-blacklisted IP @@ -212,6 +228,10 @@ class SimpleHttpClient(object): if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix) + # We use this for our body producers to ensure that they use the correct + # reactor. + self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor())) + self.user_agent = self.user_agent.encode("ascii") if self._ip_blacklist: @@ -292,7 +312,9 @@ class SimpleHttpClient(object): try: body_producer = None if data is not None: - body_producer = QuieterFileBodyProducer(BytesIO(data)) + body_producer = QuieterFileBodyProducer( + BytesIO(data), cooperator=self._cooperator, + ) request_deferred = treq.request( method, diff --git a/synapse/server.pyi b/synapse/server.pyi index 58cd099e6..cd50c721b 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -20,6 +20,7 @@ import synapse.handlers.room import synapse.handlers.room_member import synapse.handlers.set_password import synapse.http.client +import synapse.http.matrixfederationclient import synapse.notifier import synapse.push.pusherpool import synapse.replication.tcp.client @@ -143,3 +144,7 @@ class HomeServer(object): pass def get_replication_streams(self) -> Dict[str, Stream]: pass + def get_http_client( + self, + ) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient: + pass diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 9d4f0bbe4..06575ba0a 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple import attr @@ -26,8 +26,9 @@ from synapse.app.generic_worker import ( GenericWorkerReplicationHandler, GenericWorkerServer, ) +from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest -from synapse.replication.http import streams +from synapse.replication.http import ReplicationRestResource, streams from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -35,7 +36,7 @@ from synapse.server import HomeServer from synapse.util import Clock from tests import unittest -from tests.server import FakeTransport +from tests.server import FakeTransport, render logger = logging.getLogger(__name__) @@ -180,6 +181,159 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.assertEqual(request.method, b"GET") +class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): + """Base class for tests running multiple workers. + + Automatically handle HTTP replication requests from workers to master, + unlike `BaseStreamTestCase`. + """ + + servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]] + + def setUp(self): + super().setUp() + + # build a replication server + self.server_factory = ReplicationStreamProtocolFactory(self.hs) + self.streamer = self.hs.get_replication_streamer() + + store = self.hs.get_datastore() + self.database = store.db + + self.reactor.lookups["testserv"] = "1.2.3.4" + + self._worker_hs_to_resource = {} + + # When we see a connection attempt to the master replication listener we + # automatically set up the connection. This is so that tests don't + # manually have to go and explicitly set it up each time (plus sometimes + # it is impossible to write the handling explicitly in the tests). + self.reactor.add_tcp_client_callback( + "1.2.3.4", 8765, self._handle_http_replication_attempt + ) + + def create_test_json_resource(self): + """Overrides `HomeserverTestCase.create_test_json_resource`. + """ + # We override this so that it automatically registers all the HTTP + # replication servlets, without having to explicitly do that in all + # subclassses. + + resource = ReplicationRestResource(self.hs) + + for servlet in self.servlets: + servlet(self.hs, resource) + + return resource + + def make_worker_hs( + self, worker_app: str, extra_config: dict = {}, **kwargs + ) -> HomeServer: + """Make a new worker HS instance, correctly connecting replcation + stream to the master HS. + + Args: + worker_app: Type of worker, e.g. `synapse.app.federation_sender`. + extra_config: Any extra config to use for this instances. + **kwargs: Options that get passed to `self.setup_test_homeserver`, + useful to e.g. pass some mocks for things like `http_client` + + Returns: + The new worker HomeServer instance. + """ + + config = self._get_worker_hs_config() + config["worker_app"] = worker_app + config.update(extra_config) + + worker_hs = self.setup_test_homeserver( + homeserverToUse=GenericWorkerServer, + config=config, + reactor=self.reactor, + **kwargs + ) + + store = worker_hs.get_datastore() + store.db._db_pool = self.database._db_pool + + repl_handler = ReplicationCommandHandler(worker_hs) + client = ClientReplicationStreamProtocol( + worker_hs, "client", "test", self.clock, repl_handler, + ) + server = self.server_factory.buildProtocol(None) + + client_transport = FakeTransport(server, self.reactor) + client.makeConnection(client_transport) + + server_transport = FakeTransport(client, self.reactor) + server.makeConnection(server_transport) + + # Set up a resource for the worker + resource = ReplicationRestResource(self.hs) + + for servlet in self.servlets: + servlet(worker_hs, resource) + + self._worker_hs_to_resource[worker_hs] = resource + + return worker_hs + + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + return config + + def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest): + render(request, self._worker_hs_to_resource[worker_hs], self.reactor) + + def replicate(self): + """Tell the master side of replication that something has happened, and then + wait for the replication to occur. + """ + self.streamer.on_notifier_poke() + self.pump() + + def _handle_http_replication_attempt(self): + """Handles a connection attempt to the master replication HTTP + listener. + """ + + # We should have at least one outbound connection attempt, where the + # last is one to the HTTP repication IP/port. + clients = self.reactor.tcpClients + self.assertGreaterEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop() + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8765) + + # Set up client side protocol + client_protocol = client_factory.buildProtocol(None) + + request_factory = OneShotRequestFactory() + + # Set up the server side protocol + channel = _PushHTTPChannel(self.reactor) + channel.requestFactory = request_factory + channel.site = self.site + + # Connect client to server and vice versa. + client_to_server_transport = FakeTransport( + channel, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) + + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, channel + ) + channel.makeConnection(server_to_client_transport) + + # Note: at this point we've wired everything up, but we need to return + # before the data starts flowing over the connections as this is called + # inside `connecTCP` before the connection has been passed back to the + # code that requested the TCP connection. + + class TestReplicationDataHandler(GenericWorkerReplicationHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" @@ -241,6 +395,14 @@ class _PushHTTPChannel(HTTPChannel): # We need to manually stop the _PullToPushProducer. self._pull_to_push_producer.stop() + def checkPersistence(self, request, version): + """Check whether the connection can be re-used + """ + # We hijack this to always say no for ease of wiring stuff up in + # `handle_http_replication_attempt`. + request.responseHeaders.setRawHeaders(b"connection", [b"close"]) + return False + class _PullToPushProducer: """A push producer that wraps a pull producer. diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index b7d753e0a..86c03fd89 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -15,63 +15,26 @@ import logging from synapse.api.constants import LoginType -from synapse.app.generic_worker import GenericWorkerServer -from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest -from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest.client.v2_alpha import register -from tests import unittest +from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker -from tests.server import FakeChannel, render +from tests.server import FakeChannel logger = logging.getLogger(__name__) -class ClientReaderTestCase(unittest.HomeserverTestCase): +class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): """Base class for tests of the replication streams""" - servlets = [ - register.register_servlets, - ] + servlets = [register.register_servlets] def prepare(self, reactor, clock, hs): - # build a replication server - self.server_factory = ReplicationStreamProtocolFactory(hs) - self.streamer = hs.get_replication_streamer() - - store = hs.get_datastore() - self.database = store.db - self.recaptcha_checker = DummyRecaptchaChecker(hs) auth_handler = hs.get_auth_handler() auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker - self.reactor.lookups["testserv"] = "1.2.3.4" - - def make_worker_hs(self, extra_config={}): - config = self._get_worker_hs_config() - config.update(extra_config) - - worker_hs = self.setup_test_homeserver( - homeserverToUse=GenericWorkerServer, config=config, reactor=self.reactor, - ) - - store = worker_hs.get_datastore() - store.db._db_pool = self.database._db_pool - - # Register the expected servlets, essentially this is HomeserverTestCase.create_test_json_resource. - resource = JsonResource(self.hs) - - for servlet in self.servlets: - servlet(worker_hs, resource) - - # Essentially HomeserverTestCase.render. - def _render(request): - render(request, self.resource, self.reactor) - - return worker_hs, _render - def _get_worker_hs_config(self) -> dict: config = self.default_config() config["worker_app"] = "synapse.app.client_reader" @@ -82,14 +45,14 @@ class ClientReaderTestCase(unittest.HomeserverTestCase): def test_register_single_worker(self): """Test that registration works when using a single client reader worker. """ - _, worker_render = self.make_worker_hs() + worker_hs = self.make_worker_hs("synapse.app.client_reader") request_1, channel_1 = self.make_request( "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, ) # type: SynapseRequest, FakeChannel - worker_render(request_1) + self.render_on_worker(worker_hs, request_1) self.assertEqual(request_1.code, 401) # Grab the session @@ -99,7 +62,7 @@ class ClientReaderTestCase(unittest.HomeserverTestCase): request_2, channel_2 = self.make_request( "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} ) # type: SynapseRequest, FakeChannel - worker_render(request_2) + self.render_on_worker(worker_hs, request_2) self.assertEqual(request_2.code, 200) # We're given a registered user. @@ -108,15 +71,15 @@ class ClientReaderTestCase(unittest.HomeserverTestCase): def test_register_multi_worker(self): """Test that registration works when using multiple client reader workers. """ - _, worker_render_1 = self.make_worker_hs() - _, worker_render_2 = self.make_worker_hs() + worker_hs_1 = self.make_worker_hs("synapse.app.client_reader") + worker_hs_2 = self.make_worker_hs("synapse.app.client_reader") request_1, channel_1 = self.make_request( "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, ) # type: SynapseRequest, FakeChannel - worker_render_1(request_1) + self.render_on_worker(worker_hs_1, request_1) self.assertEqual(request_1.code, 401) # Grab the session @@ -126,7 +89,7 @@ class ClientReaderTestCase(unittest.HomeserverTestCase): request_2, channel_2 = self.make_request( "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} ) # type: SynapseRequest, FakeChannel - worker_render_2(request_2) + self.render_on_worker(worker_hs_2, request_2) self.assertEqual(request_2.code, 200) # We're given a registered user. diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 519a2dc51..8d4dbf232 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -19,132 +19,40 @@ from mock import Mock from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.app.generic_worker import GenericWorkerServer from synapse.events.builder import EventBuilderFactory -from synapse.replication.http import streams -from synapse.replication.tcp.handler import ReplicationCommandHandler -from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol -from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room from synapse.types import UserID -from tests import unittest -from tests.server import FakeTransport +from tests.replication._base import BaseMultiWorkerStreamTestCase logger = logging.getLogger(__name__) -class BaseStreamTestCase(unittest.HomeserverTestCase): - """Base class for tests of the replication streams""" - - servlets = [ - streams.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - # build a replication server - self.server_factory = ReplicationStreamProtocolFactory(hs) - self.streamer = hs.get_replication_streamer() - - store = hs.get_datastore() - self.database = store.db - - self.reactor.lookups["testserv"] = "1.2.3.4" - - def default_config(self): - conf = super().default_config() - conf["send_federation"] = False - return conf - - def make_worker_hs(self, extra_config={}): - config = self._get_worker_hs_config() - config.update(extra_config) - - mock_federation_client = Mock(spec=["put_json"]) - mock_federation_client.put_json.side_effect = lambda *_, **__: defer.succeed({}) - - worker_hs = self.setup_test_homeserver( - http_client=mock_federation_client, - homeserverToUse=GenericWorkerServer, - config=config, - reactor=self.reactor, - ) - - store = worker_hs.get_datastore() - store.db._db_pool = self.database._db_pool - - repl_handler = ReplicationCommandHandler(worker_hs) - client = ClientReplicationStreamProtocol( - worker_hs, "client", "test", self.clock, repl_handler, - ) - server = self.server_factory.buildProtocol(None) - - client_transport = FakeTransport(server, self.reactor) - client.makeConnection(client_transport) - - server_transport = FakeTransport(client, self.reactor) - server.makeConnection(server_transport) - - return worker_hs - - def _get_worker_hs_config(self) -> dict: - config = self.default_config() - config["worker_app"] = "synapse.app.federation_sender" - config["worker_replication_host"] = "testserv" - config["worker_replication_http_port"] = "8765" - return config - - def replicate(self): - """Tell the master side of replication that something has happened, and then - wait for the replication to occur. - """ - self.streamer.on_notifier_poke() - self.pump() - - def create_room_with_remote_server(self, user, token, remote_server="other_server"): - room = self.helper.create_room_as(user, tok=token) - store = self.hs.get_datastore() - federation = self.hs.get_handlers().federation_handler - - prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) - room_version = self.get_success(store.get_room_version(room)) - - factory = EventBuilderFactory(self.hs) - factory.hostname = remote_server - - user_id = UserID("user", remote_server).to_string() - - event_dict = { - "type": EventTypes.Member, - "state_key": user_id, - "content": {"membership": Membership.JOIN}, - "sender": user_id, - "room_id": room, - } - - builder = factory.for_room_version(room_version, event_dict) - join_event = self.get_success(builder.build(prev_event_ids)) - - self.get_success(federation.on_send_join_request(remote_server, join_event)) - self.replicate() - - return room - - -class FederationSenderTestCase(BaseStreamTestCase): +class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): servlets = [ login.register_servlets, register_servlets_for_client_rest_resource, room.register_servlets, ] + def default_config(self): + conf = super().default_config() + conf["send_federation"] = False + return conf + def test_send_event_single_sender(self): """Test that using a single federation sender worker correctly sends a new event. """ - worker_hs = self.make_worker_hs({"send_federation": True}) - mock_client = worker_hs.get_http_client() + mock_client = Mock(spec=["put_json"]) + mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({}) + + self.make_worker_hs( + "synapse.app.federation_sender", + {"send_federation": True}, + http_client=mock_client, + ) user = self.register_user("user", "pass") token = self.login("user", "pass") @@ -165,23 +73,29 @@ class FederationSenderTestCase(BaseStreamTestCase): """Test that using two federation sender workers correctly sends new events. """ - worker1 = self.make_worker_hs( + mock_client1 = Mock(spec=["put_json"]) + mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + self.make_worker_hs( + "synapse.app.federation_sender", { "send_federation": True, "worker_name": "sender1", "federation_sender_instances": ["sender1", "sender2"], - } + }, + http_client=mock_client1, ) - mock_client1 = worker1.get_http_client() - worker2 = self.make_worker_hs( + mock_client2 = Mock(spec=["put_json"]) + mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + self.make_worker_hs( + "synapse.app.federation_sender", { "send_federation": True, "worker_name": "sender2", "federation_sender_instances": ["sender1", "sender2"], - } + }, + http_client=mock_client2, ) - mock_client2 = worker2.get_http_client() user = self.register_user("user2", "pass") token = self.login("user2", "pass") @@ -191,8 +105,8 @@ class FederationSenderTestCase(BaseStreamTestCase): for i in range(20): server_name = "other_server_%d" % (i,) room = self.create_room_with_remote_server(user, token, server_name) - mock_client1.reset_mock() - mock_client2.reset_mock() + mock_client1.reset_mock() # type: ignore[attr-defined] + mock_client2.reset_mock() # type: ignore[attr-defined] self.create_and_send_event(room, UserID.from_string(user)) self.replicate() @@ -222,23 +136,29 @@ class FederationSenderTestCase(BaseStreamTestCase): """Test that using two federation sender workers correctly sends new typing EDUs. """ - worker1 = self.make_worker_hs( + mock_client1 = Mock(spec=["put_json"]) + mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + self.make_worker_hs( + "synapse.app.federation_sender", { "send_federation": True, "worker_name": "sender1", "federation_sender_instances": ["sender1", "sender2"], - } + }, + http_client=mock_client1, ) - mock_client1 = worker1.get_http_client() - worker2 = self.make_worker_hs( + mock_client2 = Mock(spec=["put_json"]) + mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + self.make_worker_hs( + "synapse.app.federation_sender", { "send_federation": True, "worker_name": "sender2", "federation_sender_instances": ["sender1", "sender2"], - } + }, + http_client=mock_client2, ) - mock_client2 = worker2.get_http_client() user = self.register_user("user3", "pass") token = self.login("user3", "pass") @@ -250,8 +170,8 @@ class FederationSenderTestCase(BaseStreamTestCase): for i in range(20): server_name = "other_server_%d" % (i,) room = self.create_room_with_remote_server(user, token, server_name) - mock_client1.reset_mock() - mock_client2.reset_mock() + mock_client1.reset_mock() # type: ignore[attr-defined] + mock_client2.reset_mock() # type: ignore[attr-defined] self.get_success( typing_handler.started_typing( @@ -284,3 +204,32 @@ class FederationSenderTestCase(BaseStreamTestCase): self.assertTrue(sent_on_1) self.assertTrue(sent_on_2) + + def create_room_with_remote_server(self, user, token, remote_server="other_server"): + room = self.helper.create_room_as(user, tok=token) + store = self.hs.get_datastore() + federation = self.hs.get_handlers().federation_handler + + prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) + room_version = self.get_success(store.get_room_version(room)) + + factory = EventBuilderFactory(self.hs) + factory.hostname = remote_server + + user_id = UserID("user", remote_server).to_string() + + event_dict = { + "type": EventTypes.Member, + "state_key": user_id, + "content": {"membership": Membership.JOIN}, + "sender": user_id, + "room_id": room, + } + + builder = factory.for_room_version(room_version, event_dict) + join_event = self.get_success(builder.build(prev_event_ids)) + + self.get_success(federation.on_send_join_request(remote_server, join_event)) + self.replicate() + + return room diff --git a/tests/server.py b/tests/server.py index a5e57c52f..b6e0b14e7 100644 --- a/tests/server.py +++ b/tests/server.py @@ -237,6 +237,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def __init__(self): self.threadpool = ThreadPool(self) + self._tcp_callbacks = {} self._udp = [] lookups = self.lookups = {} @@ -268,6 +269,29 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def getThreadPool(self): return self.threadpool + def add_tcp_client_callback(self, host, port, callback): + """Add a callback that will be invoked when we receive a connection + attempt to the given IP/port using `connectTCP`. + + Note that the callback gets run before we return the connection to the + client, which means callbacks cannot block while waiting for writes. + """ + self._tcp_callbacks[(host, port)] = callback + + def connectTCP(self, host, port, factory, timeout=30, bindAddress=None): + """Fake L{IReactorTCP.connectTCP}. + """ + + conn = super().connectTCP( + host, port, factory, timeout=timeout, bindAddress=None + ) + + callback = self._tcp_callbacks.get((host, port)) + if callback: + callback() + + return conn + class ThreadPool: """ @@ -486,7 +510,7 @@ class FakeTransport(object): try: self.other.dataReceived(to_write) except Exception as e: - logger.warning("Exception writing to protocol: %s", e) + logger.exception("Exception writing to protocol: %s", e) return self.buffer = self.buffer[len(to_write) :]