0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-21 18:23:53 +01:00
synapse/tests/unittest.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

794 lines
26 KiB
Python
Raw Normal View History

2016-01-07 05:26:29 +01:00
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector
2019-11-27 22:54:07 +01:00
# Copyright 2019 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
2018-10-01 16:11:58 +02:00
import hashlib
import hmac
2019-12-05 18:58:25 +01:00
import inspect
import logging
import secrets
import time
from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
2021-04-09 19:44:38 +02:00
from unittest.mock import Mock, patch
from canonicaljson import json
2019-12-05 18:58:25 +01:00
from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
from twisted.web.resource import Resource
from synapse import events
2019-11-27 22:54:07 +01:00
from synapse.api.constants import EventTypes, Membership
from synapse.config.homeserver import HomeServerConfig
2019-11-27 22:54:07 +01:00
from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.federation.transport import server as federation_server
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import (
SENTINEL_CONTEXT,
LoggingContext,
current_context,
set_current_context,
)
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
from synapse.util.httpresourcetree import create_resource_tree
2019-11-27 22:54:07 +01:00
from synapse.util.ratelimitutils import FederationRateLimiter
2020-11-15 23:49:21 +01:00
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
# Module-level side effects, executed once when this module is first imported.
# NOTE(review): presumably these prepare the test database and the logging
# configuration shared by all tests -- confirm against tests.utils.setupdb and
# tests.test_utils.logging_setup.setup_logging.
setupdb()
setup_logging()
def around(target):
    """Build a decorator that wraps an existing attribute of `target`, in the
    style of a CLOS 'around' modifier.

    The decorated function must carry the same name as the attribute it wraps.
    That attribute is looked up on `target` immediately and replaced by a
    wrapper which calls the decorated function, handing it the original
    attribute as the first positional argument followed by whatever the caller
    passed:

        @around(self)
        def method_name(orig, *args, **kwargs):
            return orig(*args, **kwargs)

    The decorator itself returns nothing, so the decorated name ends up bound
    to None in the scope where the decorator was applied.
    """

    def _around(code):
        attr_name = code.__name__
        original = getattr(target, attr_name)

        def wrapper(*args, **kwargs):
            return code(original, *args, **kwargs)

        setattr(target, attr_name, wrapper)

    return _around
# Type variable used to annotate the `request` class parameter of
# HomeserverTestCase.make_request.
T = TypeVar("T")
class TestCase(unittest.TestCase):
    """A subclass of twisted.trial's TestCase which looks for 'loglevel'
    attributes on both itself and its individual test methods, to override the
    root logger's logging level while that test (case|method) runs."""

    def __init__(self, methodName, *args, **kwargs):
        """Wrap this instance's setUp and tearDown with logcontext and
        log-level handling.

        Args:
            methodName: name of the test method this instance will run. It is
                looked up on the instance to find a per-method `loglevel`.
            *args, **kwargs: passed through to the parent constructor.
        """
        super().__init__(methodName, *args, **kwargs)

        method = getattr(self, methodName)

        # A loglevel set on the test method wins over one set on the class.
        level = getattr(method, "loglevel", getattr(self, "loglevel", None))

        @around(self)
        def setUp(orig):
            # if we're not starting in the sentinel logcontext, then to be honest
            # all future bets are off.
            if current_context():
                self.fail(
                    "Test starting with non-sentinel logging context %s"
                    % (current_context(),)
                )

            old_level = logging.getLogger().level
            if level is not None and old_level != level:

                # Only wrap tearDown when we actually change the level, so that
                # the previous root level is restored once the test finishes.
                @around(self)
                def tearDown(orig):
                    ret = orig()
                    logging.getLogger().setLevel(old_level)
                    return ret

                logging.getLogger().setLevel(level)

            # Trial messes with the warnings configuration, thus this has to be
            # done in the context of an individual TestCase.
            self.addCleanup(setup_awaitable_errors())

            return orig()

        @around(self)
        def tearDown(orig):
            ret = orig()
            # force a GC to workaround problems with deferreds leaking logcontexts when
            # they are GCed (see the logcontext docs)
            gc.collect()
            set_current_context(SENTINEL_CONTEXT)

            return ret

    def assertObjectHasAttributes(self, attrs, obj):
        """Asserts that the given object has each of the attributes given, and
        that the value of each matches according to assertEqual.

        Args:
            attrs (dict): maps attribute name to the expected value.
            obj: the object to inspect.

        Raises:
            AssertionError: if an attribute is missing, or (chained from the
                underlying assertion failure) if a value does not match.
        """
        for key, expected in attrs.items():
            if not hasattr(obj, key):
                raise AssertionError("Expected obj to have a '.%s'" % key)
            try:
                self.assertEqual(expected, getattr(obj, key))
            except AssertionError as e:
                # Re-raise as the same exception type, naming the attribute that
                # failed, while keeping the original failure as the cause.
                raise (type(e))(f"Assert error for '.{key}':") from e

    def assert_dict(self, required, actual):
        """Does a partial assert of a dict.

        Args:
            required (dict): The keys and value which MUST be in 'actual'.
            actual (dict): The test result. Extra keys will not be checked.
        """
        for key in required:
            self.assertEqual(
                required[key], actual[key], msg="%s mismatch. %s" % (key, actual)
            )
def DEBUG(target):
    """Decorator which tags `target` with a `loglevel` attribute of logging.DEBUG.

    May be applied to either a TestCase or an individual test method; the same
    object is handed back after being tagged.
    """
    setattr(target, "loglevel", logging.DEBUG)
    return target
def INFO(target):
    """Decorator which tags `target` with a `loglevel` attribute of logging.INFO.

    May be applied to either a TestCase or an individual test method; the same
    object is handed back after being tagged.
    """
    setattr(target, "loglevel", logging.INFO)
    return target
def logcontext_clean(target):
    """Decorator marking a TestCase or test method as 'logcontext_clean'.

    While the decorated target runs, `synapse.logging.context.logcontext_error`
    is patched with a replacement that raises AssertionError, so any logcontext
    error is turned into a test failure.
    """

    def _fail_on_logcontext_error(msg):
        raise AssertionError("logcontext error: %s" % (msg))

    return patch(
        "synapse.logging.context.logcontext_error", new=_fail_on_logcontext_error
    )(target)
class HomeserverTestCase(TestCase):
    """
    A base TestCase that reduces boilerplate for HomeServer-using test cases.

    Defines a setUp method which creates a mock reactor, and instantiates a homeserver
    running on that reactor.

    There are various hooks for modifying the way that the homeserver is instantiated:

    * override make_homeserver, for example by making it pass different parameters into
      setup_test_homeserver.

    * override default_config, to return a modified configuration dictionary for use
      by setup_test_homeserver.

    * On a per-test basis, you can use the @override_config decorator to give a
      dictionary containing additional configuration settings to be added to the basic
      config dict.

    Attributes:
        servlets (list[function]): List of servlet registration function.
        user_id (str): The user ID to assume if auth is hijacked.
        hijack_auth (bool): Whether to hijack auth to return the user specified
            in user_id.
        needs_threadpool (bool): If true, setUp installs and starts a real
            ThreadPool on the test reactor (and stops it on cleanup).
    """

    servlets = []
    hijack_auth = True
    needs_threadpool = False

    def __init__(self, methodName, *args, **kwargs):
        """Record any per-test config attached to the test method.

        Args:
            methodName: name of the test method this instance will run.
            *args, **kwargs: passed through to the parent constructor.
        """
        super().__init__(methodName, *args, **kwargs)

        # see if we have any additional config for this test
        method = getattr(self, methodName)
        self._extra_config = getattr(method, "_extra_config", None)

    def setUp(self):
        """
        Set up the TestCase by calling the homeserver constructor, optionally
        hijacking the authentication system to return a fixed user, and then
        calling the prepare function.

        Raises:
            Exception: if make_homeserver returned None or something that is
                not a HomeServer.
        """
        self.reactor, self.clock = get_clock()
        self._hs_args = {"clock": self.clock, "reactor": self.reactor}
        self.hs = self.make_homeserver(self.reactor, self.clock)

        # Validate the result of make_homeserver *before* touching
        # `self.hs.config`, so that a broken override produces one of these
        # explicit messages rather than an AttributeError.
        if self.hs is None:
            raise Exception("No homeserver returned from make_homeserver.")

        if not isinstance(self.hs, HomeServer):
            raise Exception("A homeserver wasn't returned, but %r" % (self.hs,))

        # Honour the `use_frozen_dicts` config option. We have to do this
        # manually because this is taken care of in the app `start` code, which
        # we don't run. Plus we want to reset it on tearDown.
        events.USE_FROZEN_DICTS = self.hs.config.use_frozen_dicts

        # create the root resource, and a site to wrap it.
        self.resource = self.create_test_resource()
        self.site = SynapseSite(
            logger_name="synapse.access.http.fake",
            site_tag=self.hs.config.server.server_name,
            config=self.hs.config.server.listeners[0],
            resource=self.resource,
            server_version_string="1",
            max_request_body_size=1234,
            reactor=self.reactor,
        )

        from tests.rest.client.v1.utils import RestHelper

        self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))

        if hasattr(self, "user_id"):
            if self.hijack_auth:

                # We need a valid token ID to satisfy foreign key constraints.
                token_id = self.get_success(
                    self.hs.get_datastore().add_access_token_to_user(
                        self.helper.auth_user_id,
                        "some_fake_token",
                        None,
                        None,
                    )
                )

                # Replacement auth callbacks: whatever the request presents,
                # they report the helper's configured user.
                async def get_user_by_access_token(token=None, allow_guest=False):
                    return {
                        "user": UserID.from_string(self.helper.auth_user_id),
                        "token_id": token_id,
                        "is_guest": False,
                    }

                async def get_user_by_req(request, allow_guest=False, rights="access"):
                    return create_requester(
                        UserID.from_string(self.helper.auth_user_id),
                        token_id,
                        False,
                        False,
                        None,
                    )

                self.hs.get_auth().get_user_by_req = get_user_by_req
                self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
                self.hs.get_auth().get_access_token_from_request = Mock(
                    return_value="1234"
                )

        if self.needs_threadpool:
            self.reactor.threadpool = ThreadPool()
            self.addCleanup(self.reactor.threadpool.stop)
            self.reactor.threadpool.start()

        if hasattr(self, "prepare"):
            self.prepare(self.reactor, self.clock, self.hs)

    def tearDown(self):
        """Undo the global state changed by setUp."""
        # Reset to not use frozen dicts.
        events.USE_FROZEN_DICTS = False

    def wait_on_thread(self, deferred, timeout=10):
        """
        Wait until a Deferred is done, where it's waiting on a real thread.

        Repeatedly advances the test reactor and sleeps for real until the
        Deferred has been called.

        Args:
            deferred: the Deferred to wait for.
            timeout: maximum wall-clock time to wait, in seconds.

        Raises:
            ValueError: if the Deferred has not been called within `timeout`.
        """
        start_time = time.time()

        while not deferred.called:
            if start_time + timeout < time.time():
                raise ValueError("Timed out waiting for threadpool")
            self.reactor.advance(0.01)
            time.sleep(0.01)

    def make_homeserver(self, reactor, clock):
        """
        Make and return a homeserver.

        Args:
            reactor: A Twisted Reactor, or something that pretends to be one.
            clock (synapse.util.Clock): The Clock, associated with the reactor.

        Returns:
            A homeserver (synapse.server.HomeServer) suitable for testing.

        Function to be overridden in subclasses.
        """
        hs = self.setup_test_homeserver()
        return hs

    def create_test_resource(self) -> Resource:
        """
        Create the root resource for the test server.

        The default calls `self.create_resource_dict` and builds the resultant dict
        into a tree.
        """
        root_resource = Resource()
        create_resource_tree(self.create_resource_dict(), root_resource)
        return root_resource

    def create_resource_dict(self) -> Dict[str, Resource]:
        """Create a resource tree for the test server

        A resource tree is a mapping from path to twisted.web.resource.

        The default implementation creates a JsonResource and calls each function in
        `servlets` to register servlets against it.
        """
        servlet_resource = JsonResource(self.hs)
        for servlet in self.servlets:
            servlet(self.hs, servlet_resource)
        return {
            "/_matrix/client": servlet_resource,
            "/_synapse/admin": servlet_resource,
        }

    def default_config(self):
        """
        Get a default HomeServer config dict.

        Returns:
            The module-level default config for server name "test", updated
            with any settings supplied through the override_config decorator.
        """
        config = default_config("test")

        # apply any additional config which was specified via the override_config
        # decorator.
        if self._extra_config is not None:
            config.update(self._extra_config)

        return config

    def prepare(self, reactor, clock, homeserver):
        """
        Prepare for the test.  This involves things like mocking out parts of
        the homeserver, or building test data common across the whole test
        suite.

        Args:
            reactor: A Twisted Reactor, or something that pretends to be one.
            clock (synapse.util.Clock): The Clock, associated with the reactor.
            homeserver (synapse.server.HomeServer): The HomeServer to test
            against.

        Function to optionally be overridden in subclasses.
        """

    def make_request(
        self,
        method: Union[bytes, str],
        path: Union[bytes, str],
        content: Union[bytes, dict] = b"",
        access_token: Optional[str] = None,
        request: Type[T] = SynapseRequest,
        shorthand: bool = True,
        federation_auth_origin: Optional[str] = None,
        content_is_form: bool = False,
        await_result: bool = True,
        custom_headers: Optional[
            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
        ] = None,
        client_ip: str = "127.0.0.1",
    ) -> FakeChannel:
        """
        Create a SynapseRequest at the path using the method and containing the
        given content.

        Args:
            method (bytes/unicode): The HTTP request method ("verb").
            path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
                escaped UTF-8 & spaces and such).
            content (bytes or dict): The body of the request. JSON-encoded, if
                a dict.
            access_token: passed through to the underlying `make_request`
                helper unchanged.
            request: the request class, passed through to the underlying
                `make_request` helper unchanged.
            shorthand: Whether to try and be helpful and prefix the given URL
                with the usual REST API path, if it doesn't contain it.
            federation_auth_origin (bytes|None): if set to not-None, we will add a fake
                Authorization header pretending to be the given server name.
            content_is_form: Whether the content is URL encoded form data. Adds the
                'Content-Type': 'application/x-www-form-urlencoded' header.

            await_result: whether to wait for the request to complete rendering. If
                 true (the default), will pump the test reactor until the renderer
                 tells the channel the request is finished.

            custom_headers: (name, value) pairs to add as request headers

            client_ip: The IP to use as the requesting IP. Useful for testing
                ratelimiting.

        Returns:
            The FakeChannel object which stores the result of the request.
        """
        return make_request(
            self.reactor,
            self.site,
            method,
            path,
            content,
            access_token,
            request,
            shorthand,
            federation_auth_origin,
            content_is_form,
            await_result,
            custom_headers,
            client_ip,
        )

    def setup_test_homeserver(self, *args, **kwargs):
        """
        Set up the test homeserver, meant to be called by the overridable
        make_homeserver. It automatically passes through the test class's
        clock & reactor.

        Args:
            See tests.utils.setup_test_homeserver.

        Returns:
            synapse.server.HomeServer
        """
        kwargs = dict(kwargs)
        kwargs.update(self._hs_args)
        if "config" not in kwargs:
            config = self.default_config()
        else:
            config = kwargs["config"]

        # Parse the config from a config dict into a HomeServerConfig
        config_obj = HomeServerConfig()
        config_obj.parse_config_dict(config, "", "")
        kwargs["config"] = config_obj

        # `stor` is bound further down, before this coroutine is ever run.
        async def run_bg_updates():
            with LoggingContext("run_bg_updates"):
                while not await stor.db_pool.updates.has_completed_background_updates():
                    await stor.db_pool.updates.do_next_background_update(1)

        hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
        stor = hs.get_datastore()

        # Run the database background updates, when running against "master".
        if hs.__class__.__name__ == "TestHomeServer":
            self.get_success(run_bg_updates())

        return hs

    def pump(self, by=0.0):
        """
        Pump the reactor enough that Deferreds will fire.

        Args:
            by: the amount to advance the reactor on each of the 100 steps.
        """
        self.reactor.pump([by] * 100)

    def get_success(self, d, by=0.0):
        """Drive an awaitable to completion and return its successful result.

        Args:
            d: an awaitable or Deferred. Anything else is returned unchanged.
            by: passed to `pump` to advance the reactor.

        Returns:
            The result of `d`, as reported by `successResultOf`.
        """
        if inspect.isawaitable(d):
            d = ensureDeferred(d)
        if not isinstance(d, Deferred):
            return d
        self.pump(by=by)
        return self.successResultOf(d)

    def get_failure(self, d, exc):
        """
        Run a Deferred and get a Failure from it. The failure must be of the type `exc`.

        Args:
            d: an awaitable or Deferred. Anything else is returned unchanged.
            exc: the expected exception type, passed to `failureResultOf`.
        """
        if inspect.isawaitable(d):
            d = ensureDeferred(d)
        if not isinstance(d, Deferred):
            return d
        self.pump()
        return self.failureResultOf(d, exc)

    def get_success_or_raise(self, d, by=0.0):
        """Drive deferred to completion and return result or raise exception
        on failure.

        Args:
            d: an awaitable or Deferred. Anything else is returned unchanged.
            by: passed to `pump` to advance the reactor.
        """
        # Always bind `deferred`, so that a plain (non-awaitable) value is
        # returned as-is below instead of hitting an unbound local.
        if inspect.isawaitable(d):
            deferred = ensureDeferred(d)
        else:
            deferred = d

        if not isinstance(deferred, Deferred):
            return d

        results: list = []
        deferred.addBoth(results.append)

        self.pump(by=by)

        if not results:
            self.fail(
                "Success result expected on {!r}, found no result instead".format(
                    deferred
                )
            )

        result = results[0]

        if isinstance(result, Failure):
            result.raiseException()

        return result

    def register_user(
        self,
        username: str,
        password: str,
        admin: Optional[bool] = False,
        displayname: Optional[str] = None,
    ) -> str:
        """
        Register a user. Requires the Admin API be registered.

        Args:
            username: The user part of the new user.
            password: The password of the new user.
            admin: Whether the user should be created as an admin or not.
            displayname: The displayname of the new user.

        Returns:
            The MXID of the new user.
        """
        self.hs.config.registration_shared_secret = "shared"

        # Create the user
        channel = self.make_request("GET", "/_synapse/admin/v1/register")
        self.assertEqual(channel.code, 200, msg=channel.result)
        nonce = channel.json_body["nonce"]

        # The MAC is keyed with the shared secret set above and covers the
        # nonce, username, password and admin flag, separated by NUL bytes.
        want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
        nonce_str = b"\x00".join([username.encode("utf8"), password.encode("utf8")])
        if admin:
            nonce_str += b"\x00admin"
        else:
            nonce_str += b"\x00notadmin"

        want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
        want_mac = want_mac.hexdigest()

        body = json.dumps(
            {
                "nonce": nonce,
                "username": username,
                "displayname": displayname,
                "password": password,
                "admin": admin,
                "mac": want_mac,
                "inhibit_login": True,
            }
        )
        channel = self.make_request(
            "POST", "/_synapse/admin/v1/register", body.encode("utf8")
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        user_id = channel.json_body["user_id"]
        return user_id

    def login(
        self,
        username,
        password,
        device_id=None,
        custom_headers: Optional[
            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
        ] = None,
    ):
        """
        Log in a user, and get an access token. Requires the Login API be
        registered.

        Args:
            username: the user to log in as.
            password: that user's password.
            device_id: if truthy, sent as `device_id` in the login request.
            custom_headers: (name, value) pairs to add as request headers.

        Returns:
            The `access_token` from the login response.
        """
        body = {"type": "m.login.password", "user": username, "password": password}
        if device_id:
            body["device_id"] = device_id

        channel = self.make_request(
            "POST",
            "/_matrix/client/r0/login",
            json.dumps(body).encode("utf8"),
            custom_headers=custom_headers,
        )
        self.assertEqual(channel.code, 200, channel.result)

        access_token = channel.json_body["access_token"]
        return access_token

    def create_and_send_event(
        self, room_id, user, soft_failed=False, prev_event_ids=None
    ):
        """
        Create and send an event.

        Args:
            room_id: the room to send the event into.
            user: the sender; used to build the requester, and its
                `to_string()` is used as the event's `sender`.
            soft_failed (bool): Whether to create a soft failed event or not
            prev_event_ids (list[str]|None): Explicitly set the prev events,
                or if None just use the default

        Returns:
            str: The new event's ID.
        """
        event_creator = self.hs.get_event_creation_handler()
        requester = create_requester(user)

        event, context = self.get_success(
            event_creator.create_event(
                requester,
                {
                    "type": EventTypes.Message,
                    "room_id": room_id,
                    "sender": user.to_string(),
                    # A random body, so that each call sends different content.
                    "content": {"body": secrets.token_hex(), "msgtype": "m.text"},
                },
                prev_event_ids=prev_event_ids,
            )
        )

        if soft_failed:
            event.internal_metadata.soft_failed = True

        self.get_success(
            event_creator.handle_new_client_event(requester, event, context)
        )

        return event.event_id

    def add_extremity(self, room_id, event_id):
        """
        Add the given event as an extremity to the room.

        Inserts a row into `event_forward_extremities` and then invalidates
        the cached latest event IDs for the room.
        """
        self.get_success(
            self.hs.get_datastore().db_pool.simple_insert(
                table="event_forward_extremities",
                values={"room_id": room_id, "event_id": event_id},
                desc="test_add_extremity",
            )
        )

        self.hs.get_datastore().get_latest_event_ids_in_room.invalidate((room_id,))

    def attempt_wrong_password_login(self, username, password):
        """Attempts to login as the user with the given password, asserting
        that the attempt *fails*.
        """
        body = {"type": "m.login.password", "user": username, "password": password}

        channel = self.make_request(
            "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
        )
        self.assertEqual(channel.code, 403, channel.result)

    def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
        """
        Inject a membership event into a room.

        Deprecated: use event_injection.inject_member_event directly

        Args:
            room: Room ID to inject the event into.
            user: MXID of the user to inject the membership for.
            membership: The membership type.
        """
        self.get_success(
            event_injection.inject_member_event(self.hs, room, user, membership)
        )
2019-11-27 22:54:07 +01:00
class FederatingHomeserverTestCase(HomeserverTestCase):
    """
    A federating homeserver that authenticates incoming requests as `other.example.com`.
    """

    def create_resource_dict(self) -> Dict[str, Resource]:
        """Extend the parent's resource mapping with a federation endpoint.

        Returns:
            The parent's mapping with a TestTransportLayerServer added under
            "/_matrix/federation".
        """
        resources = super().create_resource_dict()
        resources["/_matrix/federation"] = TestTransportLayerServer(self.hs)
        return resources
class TestTransportLayerServer(JsonResource):
    """A test implementation of TransportLayerServer

    authenticates incoming requests as `other.example.com`.
    """

    def __init__(self, hs):
        """Register the federation servlets on this resource, using a stub
        authenticator and a fixed rate-limit configuration.

        Args:
            hs: the homeserver; supplies the clock and is passed to the
                servlet registration.
        """
        super().__init__(hs)

        class _FixedOriginAuthenticator:
            # Stand-in authenticator: every request is reported as coming
            # from "other.example.com", whatever it contains.
            def authenticate_request(self, request, content):
                return succeed("other.example.com")

        fixed_authenticator = _FixedOriginAuthenticator()

        limiter_settings = {
            "window_size": 1,
            "sleep_limit": 1,
            "sleep_msec": 1,
            "reject_limit": 1000,
            "concurrent_requests": 1000,
        }
        limiter = FederationRateLimiter(
            hs.get_clock(), FederationRateLimitConfig(**limiter_settings)
        )

        federation_server.register_servlets(hs, self, fixed_authenticator, limiter)
2019-11-27 22:54:07 +01:00
def override_config(extra_config):
    """A decorator which can be applied to test functions to give additional HS config

    For example:

        class MyTestCase(HomeserverTestCase):
            @override_config({"enable_registration": False, ...})
            def test_foo(self):
                ...

    The settings are stored on the decorated function as `_extra_config`, and
    the function itself is returned unchanged otherwise.

    Args:
        extra_config(dict): Additional config settings to be merged into the default
            config dict before instantiating the test homeserver.
    """

    def decorator(func):
        setattr(func, "_extra_config", extra_config)
        return func

    return decorator
TV = TypeVar("TV")


def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
    """A test decorator which will skip the decorated test unless a condition is set

    For example:

        class MyTestCase(TestCase):
            @skip_unless(HAS_FOO, "Cannot test without foo")
            def test_foo(self):
                ...

    Args:
        condition: If false, the test will be skipped; if true, it is left
            untouched.
        reason: the reason to give for skipping the test
    """

    def decorator(f: TV) -> TV:
        if condition:
            return f
        # The skip reason is recorded on the test as its `skip` attribute.
        f.skip = reason  # type: ignore
        return f

    return decorator