fix up various test cases

A few test cases were relying on being able to mount non-client servlets on the
test resource. it's better to give them their own Resources.
This commit is contained in:
Richard van der Hoff 2020-12-02 15:26:25 +00:00
parent 693516e756
commit 7ea85302f3
5 changed files with 38 additions and 17 deletions

View file

@ -1462,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
Args:
hs (synapse.server.HomeServer): homeserver
resource (TransportLayerServer): resource class to register to
resource (JsonResource): resource class to register to
authenticator (Authenticator): authenticator to use
ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
servlet_groups (list[str], optional): List of servlet groups to register.

View file

@ -15,18 +15,20 @@
import json
from typing import Dict
from mock import ANY, Mock, call
from twisted.internet import defer
from twisted.web.resource import Resource
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
from synapse.types import UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import register_federation_servlets
# Some local users to test with
U_APPLE = UserID.from_string("@apple:test")
@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
servlets = [register_federation_servlets]
def make_homeserver(self, reactor, clock):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
@ -77,6 +77,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
return hs
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_matrix/federation"] = TransportLayerServer(self.hs)
return d
def prepare(self, reactor, clock, hs):
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
import attr
@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
from twisted.web.resource import Resource
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
)
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
servlets = [
streams.register_servlets,
]
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_synapse/replication"] = ReplicationRestResource(self.hs)
return d
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"

View file

@ -216,8 +216,9 @@ def make_request(
and not path.startswith(b"/_matrix")
and not path.startswith(b"/_synapse")
):
if path.startswith(b"/"):
path = path[1:]
path = b"/_matrix/client/r0/" + path
path = path.replace(b"//", b"/")
if not path.startswith(b"/"):
path = b"/" + path

View file

@ -705,13 +705,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
A federating homeserver that authenticates incoming requests as `other.example.com`.
"""
def prepare(self, reactor, clock, homeserver):
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
return d
class TestTransportLayerServer(JsonResource):
"""A test implementation of TransportLayerServer
authenticates incoming requests as `other.example.com`.
"""
def __init__(self, hs):
super().__init__(hs)
class Authenticator:
def authenticate_request(self, request, content):
return succeed("other.example.com")
authenticator = Authenticator()
ratelimiter = FederationRateLimiter(
clock,
hs.get_clock(),
FederationRateLimitConfig(
window_size=1,
sleep_limit=1,
@ -720,11 +736,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
concurrent_requests=1000,
),
)
federation_server.register_servlets(
homeserver, self.resource, Authenticator(), ratelimiter
)
return super().prepare(reactor, clock, homeserver)
federation_server.register_servlets(hs, self, authenticator, ratelimiter)
def override_config(extra_config):