# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector
# Copyright 2019 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import hashlib
import hmac
import logging
import secrets
import time
from typing import (
    Any,
    Awaitable,
    Callable,
    ClassVar,
    Dict,
    Generic,
    Iterable,
    List,
    NoReturn,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
)
from unittest.mock import Mock, patch

import canonicaljson
import signedjson.key
import unpaddedbase64
from typing_extensions import Concatenate, ParamSpec, Protocol

from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
from twisted.web.resource import Resource
from twisted.web.server import Request

from synapse import events
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import (
    SENTINEL_CONTEXT,
    LoggingContext,
    current_context,
    set_current_context,
)
from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree

from tests.server import (
    CustomHeaderType,
    FakeChannel,
    ThreadedMemoryReactorClock,
    get_clock,
    make_request,
    setup_test_homeserver,
)
from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb

setupdb()
setup_logging()

TV = TypeVar("TV")
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)

P = ParamSpec("P")
R = TypeVar("R")
S = TypeVar("S")


class _TypedFailure(Generic[_ExcType], Protocol):
    """Extension to twisted.Failure, where the 'value' has a certain type."""

    @property
    def value(self) -> _ExcType:
        ...


def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
    """A CLOS-style 'around' modifier, which wraps the original method of the
    given instance with another piece of code.

    The decorated function is not returned; instead it replaces the attribute of
    the same name on `target`, receiving the original callable as its first
    argument.

    @around(self)
    def method_name(orig, *args, **kwargs):
        return orig(*args, **kwargs)
    """

    def _around(code: Callable[Concatenate[S, P], R]) -> None:
        name = code.__name__
        orig = getattr(target, name)

        def new(*args: P.args, **kwargs: P.kwargs) -> R:
            return code(orig, *args, **kwargs)

        setattr(target, name, new)

    return _around


class TestCase(unittest.TestCase):
    """A subclass of twisted.trial's TestCase which looks for 'loglevel'
    attributes on both itself and its individual test methods, to override the
    root logger's logging level while that test (case|method) runs."""

    def __init__(self, methodName: str):
        super().__init__(methodName)

        method = getattr(self, methodName)

        # A per-method `loglevel` takes precedence over a per-class one.
        level = getattr(method, "loglevel", getattr(self, "loglevel", None))

        @around(self)
        def setUp(orig: Callable[[], R]) -> R:
            # if we're not starting in the sentinel logcontext, then to be honest
            # all future bets are off.
            if current_context():
                self.fail(
                    "Test starting with non-sentinel logging context %s"
                    % (current_context(),)
                )

            old_level = logging.getLogger().level
            if level is not None and old_level != level:

                # Restore the previous root log level once the test finishes.
                @around(self)
                def tearDown(orig: Callable[[], R]) -> R:
                    ret = orig()
                    logging.getLogger().setLevel(old_level)
                    return ret

                logging.getLogger().setLevel(level)

            # Trial messes with the warnings configuration, thus this has to be
            # done in the context of an individual TestCase.
            self.addCleanup(setup_awaitable_errors())

            return orig()

        @around(self)
        def tearDown(orig: Callable[[], R]) -> R:
            ret = orig()
            # force a GC to workaround problems with deferreds leaking logcontexts when
            # they are GCed (see the logcontext docs)
            gc.collect()
            set_current_context(SENTINEL_CONTEXT)

            return ret

    def assertObjectHasAttributes(self, attrs: Dict[str, object], obj: object) -> None:
        """Asserts that the given object has each of the attributes given, and
        that the value of each matches according to assertEqual."""
        for key in attrs.keys():
            if not hasattr(obj, key):
                raise AssertionError("Expected obj to have a '.%s'" % key)
            try:
                self.assertEqual(attrs[key], getattr(obj, key))
            except AssertionError as e:
                raise (type(e))(f"Assert error for '.{key}':") from e

    def assert_dict(self, required: dict, actual: dict) -> None:
        """Does a partial assert of a dict.

        Args:
            required: The keys and value which MUST be in 'actual'.
            actual: The test result. Extra keys will not be checked.
        """
        for key in required:
            self.assertEqual(
                required[key], actual[key], msg="%s mismatch. %s" % (key, actual)
            )


def DEBUG(target: TV) -> TV:
    """A decorator to set the .loglevel attribute to logging.DEBUG.
    Can apply to either a TestCase or an individual test method."""
    target.loglevel = logging.DEBUG  # type: ignore[attr-defined]
    return target


def INFO(target: TV) -> TV:
    """A decorator to set the .loglevel attribute to logging.INFO.
    Can apply to either a TestCase or an individual test method."""
    target.loglevel = logging.INFO  # type: ignore[attr-defined]
    return target


def logcontext_clean(target: TV) -> TV:
    """A decorator which marks the TestCase or method as 'logcontext_clean'

    ... ie, any logcontext errors should cause a test failure
    """

    def logcontext_error(msg: str) -> NoReturn:
        raise AssertionError("logcontext error: %s" % (msg))

    patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
    return patcher(target)  # type: ignore[call-overload]


class HomeserverTestCase(TestCase):
    """
    A base TestCase that reduces boilerplate for HomeServer-using test cases.

    Defines a setUp method which creates a mock reactor, and instantiates a homeserver
    running on that reactor.

    There are various hooks for modifying the way that the homeserver is instantiated:

    * override make_homeserver, for example by making it pass different parameters into
      setup_test_homeserver.

    * override default_config, to return a modified configuration dictionary for use
      by setup_test_homeserver.

    * On a per-test basis, you can use the @override_config decorator to give a
      dictionary containing additional configuration settings to be added to the basic
      config dict.

    Attributes:
        servlets: List of servlet registration functions.
        user_id (str): The user ID to assume if auth is hijacked.
        hijack_auth: Whether to hijack auth to return the user specified
           in user_id.
        needs_threadpool: Whether to give the test reactor a real ThreadPool,
           started in setUp and stopped on cleanup.
    """

    hijack_auth: ClassVar[bool] = True
    needs_threadpool: ClassVar[bool] = False
    servlets: ClassVar[List[RegisterServletsFunc]] = []

    def __init__(self, methodName: str):
        super().__init__(methodName)

        # see if we have any additional config for this test
        method = getattr(self, methodName)
        self._extra_config = getattr(method, "_extra_config", None)

    def setUp(self) -> None:
        """
        Set up the TestCase by calling the homeserver constructor, optionally
        hijacking the authentication system to return a fixed user, and then
        calling the prepare function.
        """
        self.reactor, self.clock = get_clock()
        self._hs_args = {"clock": self.clock, "reactor": self.reactor}
        self.hs = self.make_homeserver(self.reactor, self.clock)

        # Validate what make_homeserver returned *before* touching its config:
        # otherwise a broken override surfaces as an opaque AttributeError
        # (e.g. on `None.config`) instead of these explicit messages.
        if self.hs is None:
            raise Exception("No homeserver returned from make_homeserver.")

        if not isinstance(self.hs, HomeServer):
            raise Exception("A homeserver wasn't returned, but %r" % (self.hs,))

        # Honour the `use_frozen_dicts` config option. We have to do this
        # manually because this is taken care of in the app `start` code, which
        # we don't run. Plus we want to reset it on tearDown.
        events.USE_FROZEN_DICTS = self.hs.config.server.use_frozen_dicts

        # create the root resource, and a site to wrap it.
        self.resource = self.create_test_resource()
        self.site = SynapseSite(
            logger_name="synapse.access.http.fake",
            site_tag=self.hs.config.server.server_name,
            config=self.hs.config.server.listeners[0],
            resource=self.resource,
            server_version_string="1",
            max_request_body_size=4096,
            reactor=self.reactor,
        )

        from tests.rest.client.utils import RestHelper

        self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))

        if hasattr(self, "user_id"):
            if self.hijack_auth:
                assert self.helper.auth_user_id is not None
                token = "some_fake_token"

                # We need a valid token ID to satisfy foreign key constraints.
                token_id = self.get_success(
                    self.hs.get_datastores().main.add_access_token_to_user(
                        self.helper.auth_user_id,
                        token,
                        None,
                        None,
                    )
                )

                # This has to be a function and not just a Mock, because
                # `self.helper.auth_user_id` is temporarily reassigned in some tests
                async def get_requester(*args: Any, **kwargs: Any) -> Requester:
                    assert self.helper.auth_user_id is not None
                    return create_requester(
                        user_id=UserID.from_string(self.helper.auth_user_id),
                        access_token_id=token_id,
                    )

                # Type ignore: mypy doesn't like us assigning to methods.
                self.hs.get_auth().get_user_by_req = get_requester  # type: ignore[assignment]
                self.hs.get_auth().get_user_by_access_token = get_requester  # type: ignore[assignment]
                self.hs.get_auth().get_access_token_from_request = Mock(return_value=token)  # type: ignore[assignment]

        if self.needs_threadpool:
            self.reactor.threadpool = ThreadPool()  # type: ignore[assignment]
            self.addCleanup(self.reactor.threadpool.stop)
            self.reactor.threadpool.start()

        if hasattr(self, "prepare"):
            self.prepare(self.reactor, self.clock, self.hs)

    def tearDown(self) -> None:
        # Reset to not use frozen dicts.
        events.USE_FROZEN_DICTS = False

    def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None:
        """
        Wait until a Deferred is done, where it's waiting on a real thread.

        Args:
            deferred: The Deferred to wait on.
            timeout: Wall-clock seconds (measured with `time.time`) to wait
                before giving up.

        Raises:
            ValueError: if `deferred` has not fired within `timeout` seconds.
        """
        start_time = time.time()

        while not deferred.called:
            if start_time + timeout < time.time():
                raise ValueError("Timed out waiting for threadpool")
            # Advance the fake reactor and yield real time to the worker thread.
            self.reactor.advance(0.01)
            time.sleep(0.01)

    def wait_for_background_updates(self) -> None:
        """Block until all background database updates have completed."""
        store = self.hs.get_datastores().main
        while not self.get_success(
            store.db_pool.updates.has_completed_background_updates()
        ):
            self.get_success(
                store.db_pool.updates.do_next_background_update(False), by=0.1
            )

    def make_homeserver(self, reactor: ThreadedMemoryReactorClock, clock: Clock):
        """
        Make and return a homeserver.

        Args:
            reactor: A Twisted Reactor, or something that pretends to be one.
            clock: The Clock, associated with the reactor.

        Returns:
            A homeserver suitable for testing.

        Function to be overridden in subclasses.
        """
        hs = self.setup_test_homeserver()
        return hs

    def create_test_resource(self) -> Resource:
        """
        Create the root resource for the test server.

        The default calls `self.create_resource_dict` and builds the resultant dict
        into a tree.
        """
        root_resource = Resource()
        create_resource_tree(self.create_resource_dict(), root_resource)
        return root_resource

    def create_resource_dict(self) -> Dict[str, Resource]:
        """Create a resource tree for the test server

        A resource tree is a mapping from path to twisted.web.resource.

        The default implementation creates a JsonResource and calls each function in
        `servlets` to register servlets against it.
        """
        servlet_resource = JsonResource(self.hs)
        for servlet in self.servlets:
            servlet(self.hs, servlet_resource)
        return {
            "/_matrix/client": servlet_resource,
            "/_synapse/admin": servlet_resource,
        }

    def default_config(self) -> JsonDict:
        """
        Get a default HomeServer config dict.

        Any settings given to the current test method via the @override_config
        decorator are merged (shallowly, via `dict.update`) on top of the defaults.
        """
        config = default_config("test")

        # apply any additional config which was specified via the override_config
        # decorator.
        if self._extra_config is not None:
            config.update(self._extra_config)

        return config

    def prepare(
        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
    ) -> None:
        """
        Prepare for the test.  This involves things like mocking out parts of
        the homeserver, or building test data common across the whole test
        suite.

        Args:
            reactor: A Twisted Reactor, or something that pretends to be one.
            clock: The Clock, associated with the reactor.
            homeserver: The HomeServer to test against.

        Function to optionally be overridden in subclasses.
        """

    def make_request(
        self,
        method: Union[bytes, str],
        path: Union[bytes, str],
        content: Union[bytes, str, JsonDict] = b"",
        access_token: Optional[str] = None,
        request: Type[Request] = SynapseRequest,
        shorthand: bool = True,
        federation_auth_origin: Optional[bytes] = None,
        content_is_form: bool = False,
        await_result: bool = True,
        custom_headers: Optional[Iterable[CustomHeaderType]] = None,
        client_ip: str = "127.0.0.1",
    ) -> FakeChannel:
        """
        Create a SynapseRequest at the path using the method and containing the
        given content.

        Args:
            method: The HTTP request method ("verb").
            path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces
                and such).
            content: The body of the request. JSON-encoded, if a dict.
            access_token: if not None, an access token to authenticate the
                request with (passed through to `tests.server.make_request`).
            request: The request class to instantiate; `SynapseRequest` by default.
            shorthand: Whether to try and be helpful and prefix the given URL
                with the usual REST API path, if it doesn't contain it.
            federation_auth_origin: if set to not-None, we will add a fake
                Authorization header pretending to be the given server name.
            content_is_form: Whether the content is URL encoded form data. Adds the
                'Content-Type': 'application/x-www-form-urlencoded' header.
            await_result: whether to wait for the request to complete rendering. If
                true (the default), will pump the test reactor until the renderer
                tells the channel the request is finished.
            custom_headers: (name, value) pairs to add as request headers
            client_ip: The IP to use as the requesting IP. Useful for testing
                ratelimiting.

        Returns:
            The FakeChannel object which stores the result of the request.
        """
        # Resolves to the module-level `tests.server.make_request`, not this method.
        return make_request(
            self.reactor,
            self.site,
            method,
            path,
            content,
            access_token,
            request,
            shorthand,
            federation_auth_origin,
            content_is_form,
            await_result,
            custom_headers,
            client_ip,
        )

    def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer:
        """
        Set up the test homeserver, meant to be called by the overridable
        make_homeserver. It automatically passes through the test class's
        clock & reactor.

        If a `config` keyword argument is given it is treated as a config dict;
        otherwise `self.default_config()` is used. Either way it is parsed into
        a HomeServerConfig before being handed on.

        Args:
            See tests.utils.setup_test_homeserver.

        Returns:
            synapse.server.HomeServer
        """
        kwargs = dict(kwargs)
        kwargs.update(self._hs_args)
        if "config" not in kwargs:
            config = self.default_config()
        else:
            config = kwargs["config"]

        # Parse the config from a config dict into a HomeServerConfig
        config_obj = HomeServerConfig()
        config_obj.parse_config_dict(config, "", "")
        kwargs["config"] = config_obj

        hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
        stor = hs.get_datastores().main

        async def run_bg_updates() -> None:
            with LoggingContext("run_bg_updates"):
                self.get_success(stor.db_pool.updates.run_background_updates(False))

        # Run the database background updates, when running against "master".
        if hs.__class__.__name__ == "TestHomeServer":
            self.get_success(run_bg_updates())

        return hs

    def pump(self, by: float = 0.0) -> None:
        """
        Pump the reactor enough that Deferreds will fire.

        Args:
            by: the amount of (fake) time to advance on each of the 100 steps.
        """
        self.reactor.pump([by] * 100)

    def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
        """Drive `d` to completion by pumping the reactor, and return its
        successful result (failing the test if it did not succeed)."""
        deferred: Deferred[TV] = ensureDeferred(d)  # type: ignore[arg-type]
        self.pump(by=by)
        return self.successResultOf(deferred)

    def get_failure(
        self, d: Awaitable[Any], exc: Type[_ExcType], by: float = 0.0
    ) -> _TypedFailure[_ExcType]:
        """
        Run a Deferred and get a Failure from it. The failure must be of the type `exc`.

        Args:
            d: the awaitable to run.
            exc: the expected exception type.
            by: the amount of (fake) time to advance on each pump step, as for
                `get_success`. Defaults to 0, matching the previous behaviour.
        """
        deferred: Deferred[Any] = ensureDeferred(d)  # type: ignore[arg-type]
        self.pump(by=by)
        return self.failureResultOf(deferred, exc)

    def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
        """Drive deferred to completion and return result or raise exception
        on failure.
        """
        deferred: Deferred[TV] = ensureDeferred(d)  # type: ignore[arg-type]

        results: list = []
        deferred.addBoth(results.append)

        self.pump(by=by)

        if not results:
            self.fail(
                "Success result expected on {!r}, found no result instead".format(
                    deferred
                )
            )

        result = results[0]

        if isinstance(result, Failure):
            result.raiseException()

        return result

    def register_user(
        self,
        username: str,
        password: str,
        admin: Optional[bool] = False,
        displayname: Optional[str] = None,
    ) -> str:
        """
        Register a user. Requires the Admin API be registered.

        Note: this sets `registration_shared_secret` on the homeserver config
        to "shared" as a side effect.

        Args:
            username: The user part of the new user.
            password: The password of the new user.
            admin: Whether the user should be created as an admin or not.
            displayname: The displayname of the new user.

        Returns:
            The MXID of the new user.
        """
        self.hs.config.registration.registration_shared_secret = "shared"

        # Create the user
        channel = self.make_request("GET", "/_synapse/admin/v1/register")
        self.assertEqual(channel.code, 200, msg=channel.result)
        nonce = channel.json_body["nonce"]

        # Compute the shared-secret registration MAC over
        # nonce \0 username \0 password \0 (admin|notadmin).
        want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
        nonce_str = b"\x00".join([username.encode("utf8"), password.encode("utf8")])
        if admin:
            nonce_str += b"\x00admin"
        else:
            nonce_str += b"\x00notadmin"
        want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
        want_mac_digest = want_mac.hexdigest()

        body = {
            "nonce": nonce,
            "username": username,
            "displayname": displayname,
            "password": password,
            "admin": admin,
            "mac": want_mac_digest,
            "inhibit_login": True,
        }
        channel = self.make_request("POST", "/_synapse/admin/v1/register", body)
        self.assertEqual(channel.code, 200, channel.json_body)

        user_id = channel.json_body["user_id"]
        return user_id

    def register_appservice_user(
        self,
        username: str,
        appservice_token: str,
    ) -> Tuple[str, str]:
        """Register an appservice user as an application service.
        Requires the client-facing registration API be registered.

        Args:
            username: the user to be registered by an application service.
                Should NOT be a full username, i.e. just "localpart" as opposed to "@localpart:hostname"
            appservice_token: the access token for that application service.

        Raises:
            if the request to '/register' does not return 200 OK (the test fails).

        Returns:
            The MXID of the new user, the device ID of the new user's first device.
        """
        channel = self.make_request(
            "POST",
            "/_matrix/client/r0/register",
            {
                "username": username,
                "type": "m.login.application_service",
            },
            access_token=appservice_token,
        )
        self.assertEqual(channel.code, 200, channel.json_body)
        return channel.json_body["user_id"], channel.json_body["device_id"]

    def login(
        self,
        username: str,
        password: str,
        device_id: Optional[str] = None,
        additional_request_fields: Optional[Dict[str, str]] = None,
        custom_headers: Optional[Iterable[CustomHeaderType]] = None,
    ) -> str:
        """
        Log in a user, and get an access token. Requires the Login API be registered.

        Args:
            username: The localpart (or user ID) of the user to log in as.
            password: The user's password.
            device_id: An optional device ID to assign to the new device created during
                login.
            additional_request_fields: A dictionary containing any additional /login
                request fields and their values.
            custom_headers: Custom HTTP headers and values to add to the /login request.

        Returns:
            The access token returned by /login for the new session.
        """

        body = {"type": "m.login.password", "user": username, "password": password}
        if device_id:
            body["device_id"] = device_id
        if additional_request_fields:
            body.update(additional_request_fields)

        channel = self.make_request(
            "POST",
            "/_matrix/client/r0/login",
            body,
            custom_headers=custom_headers,
        )
        self.assertEqual(channel.code, 200, channel.result)

        access_token = channel.json_body["access_token"]
        return access_token

    def create_and_send_event(
        self,
        room_id: str,
        user: UserID,
        soft_failed: bool = False,
        prev_event_ids: Optional[List[str]] = None,
    ) -> str:
        """
        Create and send an event.

        Args:
            room_id: The room to send the event into.
            user: The user who sends the event.
            soft_failed: Whether to create a soft failed event or not
            prev_event_ids: Explicitly set the prev events,
                or if None just use the default

        Returns:
            The new event's ID.
        """
        event_creator = self.hs.get_event_creation_handler()
        requester = create_requester(user)

        event, context = self.get_success(
            event_creator.create_event(
                requester,
                {
                    "type": EventTypes.Message,
                    "room_id": room_id,
                    "sender": user.to_string(),
                    "content": {"body": secrets.token_hex(), "msgtype": "m.text"},
                },
                prev_event_ids=prev_event_ids,
            )
        )

        if soft_failed:
            event.internal_metadata.soft_failed = True

        self.get_success(
            event_creator.handle_new_client_event(
                requester, events_and_context=[(event, context)]
            )
        )

        return event.event_id

    def inject_room_member(self, room: str, user: str, membership: str) -> None:
        """
        Inject a membership event into a room.

        Deprecated: use event_injection.inject_member_event directly

        Args:
            room: Room ID to inject the event into.
            user: MXID of the user to inject the membership for.
            membership: The membership type.
        """
        self.get_success(
            event_injection.inject_member_event(self.hs, room, user, membership)
        )


class FederatingHomeserverTestCase(HomeserverTestCase):
    """
    A federating homeserver, set up to validate incoming federation requests
    """

    OTHER_SERVER_NAME = "other.example.com"
    OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        super().prepare(reactor, clock, hs)

        # poke the other server's signing key into the key store, so that we don't
        # make requests for it
        verify_key = signedjson.key.get_verify_key(self.OTHER_SERVER_SIGNATURE_KEY)
        verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)

        self.get_success(
            hs.get_datastores().main.store_server_verify_keys(
                from_server=self.OTHER_SERVER_NAME,
                ts_added_ms=clock.time_msec(),
                verify_keys=[
                    (
                        self.OTHER_SERVER_NAME,
                        verify_key_id,
                        FetchKeyResult(
                            verify_key=verify_key,
                            valid_until_ts=clock.time_msec() + 10000,
                        ),
                    )
                ],
            )
        )

    def create_resource_dict(self) -> Dict[str, Resource]:
        d = super().create_resource_dict()
        d["/_matrix/federation"] = TransportLayerServer(self.hs)
        return d

    def make_signed_federation_request(
        self,
        method: str,
        path: str,
        content: Optional[JsonDict] = None,
        await_result: bool = True,
        custom_headers: Optional[Iterable[CustomHeaderType]] = None,
        client_ip: str = "127.0.0.1",
    ) -> FakeChannel:
        """Make an inbound signed federation request to this server

        The request is signed as if it came from "other.example.com", which our HS
        already has the keys for.

        Any `custom_headers` given are sent along with the generated
        Authorization header; the caller's iterable is not mutated.
        """

        headers: List[CustomHeaderType] = (
            list(custom_headers) if custom_headers is not None else []
        )

        headers.append(
            (
                "Authorization",
                _auth_header_for_request(
                    origin=self.OTHER_SERVER_NAME,
                    destination=self.hs.hostname,
                    signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
                    method=method,
                    path=path,
                    content=content,
                ),
            )
        )

        return make_request(
            self.reactor,
            self.site,
            method=method,
            path=path,
            content=content if content is not None else "",
            shorthand=False,
            await_result=await_result,
            custom_headers=headers,
            client_ip=client_ip,
        )

    def add_hashes_and_signatures_from_other_server(
        self,
        event_dict: JsonDict,
        room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
    ) -> JsonDict:
        """Adds hashes and signatures to the given event dict

        Returns:
             The modified event dict, for convenience
        """
        add_hashes_and_signatures(
            room_version,
            event_dict,
            signature_name=self.OTHER_SERVER_NAME,
            signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
        )
        return event_dict


def _auth_header_for_request(
    origin: str,
    destination: str,
    signing_key: signedjson.key.SigningKey,
    method: str,
    path: str,
    content: Optional[JsonDict],
) -> str:
    """Build a suitable Authorization header for an outgoing federation request

    The signature covers the canonical-JSON encoding of the method, uri,
    destination, origin and (if given) content.
    """
    request_description: JsonDict = {
        "method": method,
        "uri": path,
        "destination": destination,
        "origin": origin,
    }
    if content is not None:
        request_description["content"] = content
    signature_base64 = unpaddedbase64.encode_base64(
        signing_key.sign(
            canonicaljson.encode_canonical_json(request_description)
        ).signature
    )
    return (
        f"X-Matrix origin={origin},"
        f"key={signing_key.alg}:{signing_key.version},"
        f"sig={signature_base64}"
    )


def override_config(extra_config: JsonDict) -> Callable[[TV], TV]:
    """A decorator which can be applied to test functions to give additional HS config

    For example:

        class MyTestCase(HomeserverTestCase):
            @override_config({"enable_registration": False, ...})
            def test_foo(self):
                ...

    Args:
        extra_config: Additional config settings to be merged into the default
            config dict before instantiating the test homeserver.
    """

    def decorator(func: TV) -> TV:
        # This attribute is read back by HomeserverTestCase.__init__.
        func._extra_config = extra_config  # type: ignore[attr-defined]
        return func

    return decorator


def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
    """A test decorator which will skip the decorated test unless a condition is set

    For example:

    class MyTestCase(TestCase):
        @skip_unless(HAS_FOO, "Cannot test without foo")
        def test_foo(self):
            ...

    Args:
        condition: If false, the test will be skipped
        reason: the reason to give for skipping the test
    """

    def decorator(f: TV) -> TV:
        if not condition:
            f.skip = reason  # type: ignore
        return f

    return decorator