diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index b53e7a20e..434718ddf 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -1462,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N Args: hs (synapse.server.HomeServer): homeserver - resource (TransportLayerServer): resource class to register to + resource (JsonResource): resource class to register to authenticator (Authenticator): authenticator to use ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use servlet_groups (list[str], optional): List of servlet groups to register. diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index abbdf2d52..9a085ccaf 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -15,18 +15,20 @@ import json +from typing import Dict from mock import ANY, Mock, call from twisted.internet import defer +from twisted.web.resource import Resource from synapse.api.errors import AuthError +from synapse.federation.transport.server import TransportLayerServer from synapse.types import UserID, create_requester from tests import unittest from tests.test_utils import make_awaitable from tests.unittest import override_config -from tests.utils import register_federation_servlets # Some local users to test with U_APPLE = UserID.from_string("@apple:test") @@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content): class TypingNotificationsTestCase(unittest.HomeserverTestCase): - servlets = [register_federation_servlets] - def make_homeserver(self, reactor, clock): # we mock out the keyring so as to skip the authentication check on the # federation API call. @@ -77,6 +77,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): return hs + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_matrix/federation"] = TransportLayerServer(self.hs) + return d + def prepare(self, reactor, clock, hs): mock_notifier = hs.get_notifier() self.on_new_event = mock_notifier.on_new_event diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 295c5d58a..79738ab46 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import attr @@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.protocol import Protocol from twisted.internet.task import LoopingCall from twisted.web.http import HTTPChannel +from twisted.web.resource import Resource from synapse.app.generic_worker import ( GenericWorkerReplicationHandler, @@ -28,7 +29,7 @@ from synapse.app.generic_worker import ( ) from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite -from synapse.replication.http import ReplicationRestResource, streams +from synapse.replication.http import ReplicationRestResource from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): if not hiredis: skip = "Requires hiredis" - servlets = [ - streams.register_servlets, - ] - def prepare(self, reactor, clock, hs): # build a replication server server_factory = ReplicationStreamProtocolFactory(hs) @@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._client_transport = None self._server_transport = None + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_synapse/replication"] = ReplicationRestResource(self.hs) + return d + def _get_worker_hs_config(self) -> dict: config = self.default_config() config["worker_app"] = "synapse.app.generic_worker" diff --git a/tests/server.py b/tests/server.py index a51ad0c14..eee970c43 100644 --- a/tests/server.py +++ b/tests/server.py @@ -216,8 +216,9 @@ def make_request( and not path.startswith(b"/_matrix") and not path.startswith(b"/_synapse") ): + if path.startswith(b"/"): + path = path[1:] path = b"/_matrix/client/r0/" + path - path = path.replace(b"//", b"/") if not path.startswith(b"/"): path = b"/" + path diff --git a/tests/unittest.py b/tests/unittest.py index 425b39b1d..102b0a1f3 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -705,13 +705,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase): A federating homeserver that authenticates incoming requests as `other.example.com`. """ - def prepare(self, reactor, clock, homeserver): + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_matrix/federation"] = TestTransportLayerServer(self.hs) + return d + + +class TestTransportLayerServer(JsonResource): + """A test implementation of TransportLayerServer + + authenticates incoming requests as `other.example.com`. + """ + + def __init__(self, hs): + super().__init__(hs) + class Authenticator: def authenticate_request(self, request, content): return succeed("other.example.com") + authenticator = Authenticator() + ratelimiter = FederationRateLimiter( - clock, + hs.get_clock(), FederationRateLimitConfig( window_size=1, sleep_limit=1, @@ -720,11 +736,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase): concurrent_requests=1000, ), ) - federation_server.register_servlets( - homeserver, self.resource, Authenticator(), ratelimiter - ) - return super().prepare(reactor, clock, homeserver) + federation_server.register_servlets(hs, self, authenticator, ratelimiter) def override_config(extra_config):