diff --git a/changelog.d/9396.misc b/changelog.d/9396.misc new file mode 100644 index 000000000..df1348ec4 --- /dev/null +++ b/changelog.d/9396.misc @@ -0,0 +1 @@ +Convert tests to use `HomeserverTestCase`. diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index ee5217b07..b1a8c58e1 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -17,8 +17,6 @@ from mock import Mock import pymacaroons -from twisted.internet import defer - from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import ( @@ -33,19 +31,17 @@ from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import UserID from tests import unittest -from tests.utils import mock_getRawHeaders, setup_test_homeserver +from tests.test_utils import simple_async_mock +from tests.utils import mock_getRawHeaders -class AuthTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.state_handler = Mock() +class AuthTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = Mock() - self.hs = yield setup_test_homeserver(self.addCleanup) - self.hs.get_datastore = Mock(return_value=self.store) - self.hs.get_auth_handler().store = self.store - self.auth = Auth(self.hs) + hs.get_datastore = Mock(return_value=self.store) + hs.get_auth_handler().store = self.store + self.auth = Auth(hs) # AuthBlocking reads from the hs' config on initialization. We need to # modify its config instead of the hs' @@ -57,64 +53,59 @@ class AuthTestCase(unittest.TestCase): # this is overridden for the appservice tests self.store.get_app_service_by_token = Mock(return_value=None) - self.store.insert_client_ip = Mock(return_value=defer.succeed(None)) - self.store.is_support_user = Mock(return_value=defer.succeed(False)) + self.store.insert_client_ip = simple_async_mock(None) + self.store.is_support_user = simple_async_mock(False) - @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): user_info = TokenLookupResult( user_id=self.test_user, token_id=5, device_id="device" ) - self.store.get_user_by_access_token = Mock( - return_value=defer.succeed(user_info) - ) + self.store.get_user_by_access_token = simple_async_mock(user_info) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) + requester = self.get_success(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = defer.ensureDeferred(self.auth.get_user_by_req(request)) - f = self.failureResultOf(d, InvalidClientTokenError).value + f = self.get_failure( + self.auth.get_user_by_req(request), InvalidClientTokenError + ).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_user_missing_token(self): user_info = TokenLookupResult(user_id=self.test_user, token_id=5) - self.store.get_user_by_access_token = Mock( - return_value=defer.succeed(user_info) - ) + self.store.get_user_by_access_token = simple_async_mock(user_info) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = defer.ensureDeferred(self.auth.get_user_by_req(request)) - f = self.failureResultOf(d, MissingClientTokenError).value + f = self.get_failure( + self.auth.get_user_by_req(request), MissingClientTokenError + ).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") - @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token(self): app_service = Mock( token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) + requester = self.get_success(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) - @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token_good_ip(self): from netaddr import IPSet @@ -125,13 +116,13 @@ class AuthTestCase(unittest.TestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.getClientIP.return_value = "192.168.10.10" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) + requester = self.get_success(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_appservice_valid_token_bad_ip(self): @@ -144,42 +135,44 @@ class AuthTestCase(unittest.TestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.getClientIP.return_value = "131.111.8.42" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = defer.ensureDeferred(self.auth.get_user_by_req(request)) - f = self.failureResultOf(d, InvalidClientTokenError).value + f = self.get_failure( + self.auth.get_user_by_req(request), InvalidClientTokenError + ).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = defer.ensureDeferred(self.auth.get_user_by_req(request)) - f = self.failureResultOf(d, InvalidClientTokenError).value + f = self.get_failure( + self.auth.get_user_by_req(request), InvalidClientTokenError + ).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_appservice_missing_token(self): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = defer.ensureDeferred(self.auth.get_user_by_req(request)) - f = self.failureResultOf(d, MissingClientTokenError).value + f = self.get_failure( + self.auth.get_user_by_req(request), MissingClientTokenError + ).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") - @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token_valid_user_id(self): masquerading_user_id = b"@doppelganger:matrix.org" app_service = Mock( @@ -188,17 +181,15 @@ class AuthTestCase(unittest.TestCase): app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) # This just needs to return a truth-y value. - self.store.get_user_by_id = Mock( - return_value=defer.succeed({"is_guest": False}) - ) - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_id = simple_async_mock({"is_guest": False}) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) + requester = self.get_success(self.auth.get_user_by_req(request)) self.assertEquals( requester.user.to_string(), masquerading_user_id.decode("utf8") ) @@ -210,22 +201,18 @@ class AuthTestCase(unittest.TestCase): ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = defer.ensureDeferred(self.auth.get_user_by_req(request)) - self.failureResultOf(d, AuthError) + self.get_failure(self.auth.get_user_by_req(request), AuthError) - @defer.inlineCallbacks def test_get_user_from_macaroon(self): - self.store.get_user_by_access_token = Mock( - return_value=defer.succeed( - TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device") - ) + self.store.get_user_by_access_token = simple_async_mock( + TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device") ) user_id = "@baldrick:matrix.org" @@ -237,7 +224,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - user_info = yield defer.ensureDeferred( + user_info = self.get_success( self.auth.get_user_by_access_token(macaroon.serialize()) ) self.assertEqual(user_id, user_info.user_id) @@ -246,10 +233,9 @@ class AuthTestCase(unittest.TestCase): # from the db. self.assertEqual(user_info.device_id, "device") - @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): - self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True})) - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_id = simple_async_mock({"is_guest": True}) + self.store.get_user_by_access_token = simple_async_mock(None) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( @@ -263,20 +249,17 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("guest = true") serialized = macaroon.serialize() - user_info = yield defer.ensureDeferred( - self.auth.get_user_by_access_token(serialized) - ) + user_info = self.get_success(self.auth.get_user_by_access_token(serialized)) self.assertEqual(user_id, user_info.user_id) self.assertTrue(user_info.is_guest) self.store.get_user_by_id.assert_called_with(user_id) - @defer.inlineCallbacks def test_cannot_use_regular_token_as_guest(self): USER_ID = "@percy:matrix.org" - self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None)) - self.store.get_device = Mock(return_value=defer.succeed(None)) + self.store.add_access_token_to_user = simple_async_mock(None) + self.store.get_device = simple_async_mock(None) - token = yield defer.ensureDeferred( + token = self.get_success( self.hs.get_auth_handler().get_access_token_for_user_id( USER_ID, "DEVICE", valid_until_ms=None ) @@ -289,25 +272,21 @@ class AuthTestCase(unittest.TestCase): puppets_user_id=None, ) - def get_user(tok): + async def get_user(tok): if token != tok: - return defer.succeed(None) - return defer.succeed( - TokenLookupResult( - user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE", - ) + return None + return TokenLookupResult( + user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE", ) self.store.get_user_by_access_token = get_user - self.store.get_user_by_id = Mock( - return_value=defer.succeed({"is_guest": False}) - ) + self.store.get_user_by_id = simple_async_mock({"is_guest": False}) # check the token works request = Mock(args={}) request.args[b"access_token"] = [token.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield defer.ensureDeferred( + requester = self.get_success( self.auth.get_user_by_req(request, allow_guest=True) ) self.assertEqual(UserID.from_string(USER_ID), requester.user) @@ -323,17 +302,16 @@ class AuthTestCase(unittest.TestCase): request.args[b"access_token"] = [guest_tok.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - with self.assertRaises(InvalidClientCredentialsError) as cm: - yield defer.ensureDeferred( - self.auth.get_user_by_req(request, allow_guest=True) - ) + cm = self.get_failure( + self.auth.get_user_by_req(request, allow_guest=True), + InvalidClientCredentialsError, + ) - self.assertEqual(401, cm.exception.code) - self.assertEqual("Guest access token used for regular user", cm.exception.msg) + self.assertEqual(401, cm.value.code) + self.assertEqual("Guest access token used for regular user", cm.value.msg) self.store.get_user_by_id.assert_called_with(USER_ID) - @defer.inlineCallbacks def test_blocking_mau(self): self.auth_blocking._limit_usage_by_mau = False self.auth_blocking._max_mau_value = 50 @@ -341,77 +319,61 @@ class AuthTestCase(unittest.TestCase): small_number_of_users = 1 # Ensure no error thrown - yield defer.ensureDeferred(self.auth.check_auth_blocking()) + self.get_success(self.auth.check_auth_blocking()) self.auth_blocking._limit_usage_by_mau = True - self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(lots_of_users) - ) + self.store.get_monthly_active_count = simple_async_mock(lots_of_users) - with self.assertRaises(ResourceLimitError) as e: - yield defer.ensureDeferred(self.auth.check_auth_blocking()) - self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.exception.code, 403) + e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEquals(e.value.code, 403) # Ensure does not throw an error - self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(small_number_of_users) - ) - yield defer.ensureDeferred(self.auth.check_auth_blocking()) + self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) + self.get_success(self.auth.check_auth_blocking()) - @defer.inlineCallbacks def test_blocking_mau__depending_on_user_type(self): self.auth_blocking._max_mau_value = 50 self.auth_blocking._limit_usage_by_mau = True - self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) + self.store.get_monthly_active_count = simple_async_mock(100) # Support users allowed - yield defer.ensureDeferred( - self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT) - ) - self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) + self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)) + self.store.get_monthly_active_count = simple_async_mock(100) # Bots not allowed - with self.assertRaises(ResourceLimitError): - yield defer.ensureDeferred( - self.auth.check_auth_blocking(user_type=UserTypes.BOT) - ) - self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) + self.get_failure( + self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError + ) + self.store.get_monthly_active_count = simple_async_mock(100) # Real users not allowed - with self.assertRaises(ResourceLimitError): - yield defer.ensureDeferred(self.auth.check_auth_blocking()) + self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - @defer.inlineCallbacks def test_reserved_threepid(self): self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._max_mau_value = 1 - self.store.get_monthly_active_count = lambda: defer.succeed(2) + self.store.get_monthly_active_count = simple_async_mock(2) threepid = {"medium": "email", "address": "reserved@server.com"} unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} self.auth_blocking._mau_limits_reserved_threepids = [threepid] - with self.assertRaises(ResourceLimitError): - yield defer.ensureDeferred(self.auth.check_auth_blocking()) + self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - with self.assertRaises(ResourceLimitError): - yield defer.ensureDeferred( - self.auth.check_auth_blocking(threepid=unknown_threepid) - ) + self.get_failure( + self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError + ) - yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid)) + self.get_success(self.auth.check_auth_blocking(threepid=threepid)) - @defer.inlineCallbacks def test_hs_disabled(self): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" - with self.assertRaises(ResourceLimitError) as e: - yield defer.ensureDeferred(self.auth.check_auth_blocking()) - self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.exception.code, 403) + e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEquals(e.value.code, 403) - @defer.inlineCallbacks def test_hs_disabled_no_server_notices_user(self): """Check that 'hs_disabled_message' works correctly when there is no server_notices user. @@ -422,16 +384,14 @@ class AuthTestCase(unittest.TestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" - with self.assertRaises(ResourceLimitError) as e: - yield defer.ensureDeferred(self.auth.check_auth_blocking()) - self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.exception.code, 403) + e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEquals(e.value.code, 403) - @defer.inlineCallbacks def test_server_notices_mxid_special_cased(self): self.auth_blocking._hs_disabled = True user = "@user:server" self.auth_blocking._server_notices_mxid = user self.auth_blocking._hs_disabled_message = "Reason for being disabled" - yield defer.ensureDeferred(self.auth.check_auth_blocking(user)) + self.get_success(self.auth.check_auth_blocking(user)) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 279c94a03..ab7d29072 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -18,15 +18,12 @@ import jsonschema -from twisted.internet import defer - from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events import make_event_from_dict from tests import unittest -from tests.utils import setup_test_homeserver user_localpart = "test_user" @@ -39,9 +36,8 @@ def MockEvent(**kwargs): return make_event_from_dict(kwargs) -class FilteringTestCase(unittest.TestCase): - def setUp(self): - hs = setup_test_homeserver(self.addCleanup) +class FilteringTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.filtering = hs.get_filtering() self.datastore = hs.get_datastore() @@ -351,10 +347,9 @@ class FilteringTestCase(unittest.TestCase): self.assertTrue(Filter(definition).check(event)) - @defer.inlineCallbacks def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} - filter_id = yield defer.ensureDeferred( + filter_id = self.get_success( self.datastore.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json ) @@ -362,7 +357,7 @@ class FilteringTestCase(unittest.TestCase): event = MockEvent(sender="@foo:bar", type="m.profile") events = [event] - user_filter = yield defer.ensureDeferred( + user_filter = self.get_success( self.filtering.get_user_filter( user_localpart=user_localpart, filter_id=filter_id ) @@ -371,11 +366,10 @@ class FilteringTestCase(unittest.TestCase): results = user_filter.filter_presence(events=events) self.assertEquals(events, results) - @defer.inlineCallbacks def test_filter_presence_no_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} - filter_id = yield defer.ensureDeferred( + filter_id = self.get_success( self.datastore.add_user_filter( user_localpart=user_localpart + "2", user_filter=user_filter_json ) @@ -387,7 +381,7 @@ class FilteringTestCase(unittest.TestCase): ) events = [event] - user_filter = yield defer.ensureDeferred( + user_filter = self.get_success( self.filtering.get_user_filter( user_localpart=user_localpart + "2", filter_id=filter_id ) @@ -396,10 +390,9 @@ class FilteringTestCase(unittest.TestCase): results = user_filter.filter_presence(events=events) self.assertEquals([], results) - @defer.inlineCallbacks def test_filter_room_state_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield defer.ensureDeferred( + filter_id = self.get_success( self.datastore.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json ) @@ -407,7 +400,7 @@ class FilteringTestCase(unittest.TestCase): event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") events = [event] - user_filter = yield defer.ensureDeferred( + user_filter = self.get_success( self.filtering.get_user_filter( user_localpart=user_localpart, filter_id=filter_id ) @@ -416,10 +409,9 @@ class FilteringTestCase(unittest.TestCase): results = user_filter.filter_room_state(events=events) self.assertEquals(events, results) - @defer.inlineCallbacks def test_filter_room_state_no_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield defer.ensureDeferred( + filter_id = self.get_success( self.datastore.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json ) @@ -429,7 +421,7 @@ class FilteringTestCase(unittest.TestCase): ) events = [event] - user_filter = yield defer.ensureDeferred( + user_filter = self.get_success( self.filtering.get_user_filter( user_localpart=user_localpart, filter_id=filter_id ) @@ -454,11 +446,10 @@ class FilteringTestCase(unittest.TestCase): self.assertEquals(filtered_room_ids, ["!allowed:example.com"]) - @defer.inlineCallbacks def test_add_filter(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield defer.ensureDeferred( + filter_id = self.get_success( self.filtering.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json ) @@ -468,7 +459,7 @@ class FilteringTestCase(unittest.TestCase): self.assertEquals( user_filter_json, ( - yield defer.ensureDeferred( + self.get_success( self.datastore.get_user_filter( user_localpart=user_localpart, filter_id=0 ) @@ -476,17 +467,16 @@ class FilteringTestCase(unittest.TestCase): ), ) - @defer.inlineCallbacks def test_get_filter(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield defer.ensureDeferred( + filter_id = self.get_success( self.datastore.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json ) ) - filter = yield defer.ensureDeferred( + filter = self.get_success( self.filtering.get_user_filter( user_localpart=user_localpart, filter_id=filter_id ) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 53763cd0f..d5d3fdd99 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -35,8 +35,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_scheduler = Mock() hs = Mock() hs.get_datastore.return_value = self.mock_store - self.mock_store.get_received_ts.return_value = defer.succeed(0) - self.mock_store.set_appservice_last_pos.return_value = defer.succeed(None) + self.mock_store.get_received_ts.return_value = make_awaitable(0) + self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) hs.get_application_service_api.return_value = self.mock_as_api hs.get_application_service_scheduler.return_value = self.mock_scheduler hs.get_clock.return_value = MockClock() @@ -50,16 +50,16 @@ class AppServiceHandlerTestCase(unittest.TestCase): self._mkservice(is_interested=False), ] - self.mock_as_api.query_user.return_value = defer.succeed(True) + self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id.return_value = defer.succeed([]) + self.mock_store.get_user_by_id.return_value = make_awaitable([]) event = Mock( sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" ) self.mock_store.get_new_events_for_appservice.side_effect = [ - defer.succeed((0, [event])), - defer.succeed((0, [])), + make_awaitable((0, [event])), + make_awaitable((0, [])), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) @@ -72,13 +72,13 @@ class AppServiceHandlerTestCase(unittest.TestCase): services = [self._mkservice(is_interested=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id.return_value = defer.succeed(None) + self.mock_store.get_user_by_id.return_value = make_awaitable(None) event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") - self.mock_as_api.query_user.return_value = defer.succeed(True) + self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_store.get_new_events_for_appservice.side_effect = [ - defer.succeed((0, [event])), - defer.succeed((0, [])), + make_awaitable((0, [event])), + make_awaitable((0, [])), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) @@ -90,13 +90,13 @@ class AppServiceHandlerTestCase(unittest.TestCase): services = [self._mkservice(is_interested=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id.return_value = defer.succeed({"name": user_id}) + self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id}) event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") - self.mock_as_api.query_user.return_value = defer.succeed(True) + self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_store.get_new_events_for_appservice.side_effect = [ - defer.succeed((0, [event])), - defer.succeed((0, [])), + make_awaitable((0, [event])), + make_awaitable((0, [])), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) @@ -106,7 +106,6 @@ class AppServiceHandlerTestCase(unittest.TestCase): "query_user called when it shouldn't have been.", ) - @defer.inlineCallbacks def test_query_room_alias_exists(self): room_alias_str = "#foo:bar" room_alias = Mock() @@ -127,8 +126,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): Mock(room_id=room_id, servers=servers) ) - result = yield defer.ensureDeferred( - self.handler.query_room_alias_exists(room_alias) + result = self.successResultOf( + defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias)) ) self.mock_as_api.query_alias.assert_called_once_with( diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index e59fa70ba..f3448c94d 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -14,163 +14,11 @@ # limitations under the License. """Tests REST events for /profile paths.""" -import json - -from mock import Mock - -from twisted.internet import defer - -import synapse.types -from synapse.api.errors import AuthError, SynapseError from synapse.rest import admin from synapse.rest.client.v1 import login, profile, room from tests import unittest -from ....utils import MockHttpResource, setup_test_homeserver - -myid = "@1234ABCD:test" -PATH_PREFIX = "/_matrix/client/r0" - - -class MockHandlerProfileTestCase(unittest.TestCase): - """ Tests rest layer of profile management. - - Todo: move these into ProfileTestCase - """ - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.mock_handler = Mock( - spec=[ - "get_displayname", - "set_displayname", - "get_avatar_url", - "set_avatar_url", - "check_profile_query_allowed", - ] - ) - - self.mock_handler.get_displayname.return_value = defer.succeed(Mock()) - self.mock_handler.set_displayname.return_value = defer.succeed(Mock()) - self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock()) - self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock()) - self.mock_handler.check_profile_query_allowed.return_value = defer.succeed( - Mock() - ) - - hs = yield setup_test_homeserver( - self.addCleanup, - "test", - federation_http_client=None, - resource_for_client=self.mock_resource, - federation=Mock(), - federation_client=Mock(), - profile_handler=self.mock_handler, - ) - - async def _get_user_by_req(request=None, allow_guest=False): - return synapse.types.create_requester(myid) - - hs.get_auth().get_user_by_req = _get_user_by_req - - profile.register_servlets(hs, self.mock_resource) - - @defer.inlineCallbacks - def test_get_my_name(self): - mocked_get = self.mock_handler.get_displayname - mocked_get.return_value = defer.succeed("Frank") - - (code, response) = yield self.mock_resource.trigger( - "GET", "/profile/%s/displayname" % (myid), None - ) - - self.assertEquals(200, code) - self.assertEquals({"displayname": "Frank"}, response) - self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD") - - @defer.inlineCallbacks - def test_set_my_name(self): - mocked_set = self.mock_handler.set_displayname - mocked_set.return_value = defer.succeed(()) - - (code, response) = yield self.mock_resource.trigger( - "PUT", "/profile/%s/displayname" % (myid), b'{"displayname": "Frank Jr."}' - ) - - self.assertEquals(200, code) - self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.") - - @defer.inlineCallbacks - def test_set_my_name_noauth(self): - mocked_set = self.mock_handler.set_displayname - mocked_set.side_effect = AuthError(400, "message") - - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/profile/%s/displayname" % ("@4567:test"), - b'{"displayname": "Frank Jr."}', - ) - - self.assertTrue(400 <= code < 499, msg="code %d is in the 4xx range" % (code)) - - @defer.inlineCallbacks - def test_get_other_name(self): - mocked_get = self.mock_handler.get_displayname - mocked_get.return_value = defer.succeed("Bob") - - (code, response) = yield self.mock_resource.trigger( - "GET", "/profile/%s/displayname" % ("@opaque:elsewhere"), None - ) - - self.assertEquals(200, code) - self.assertEquals({"displayname": "Bob"}, response) - - @defer.inlineCallbacks - def test_set_other_name(self): - mocked_set = self.mock_handler.set_displayname - mocked_set.side_effect = SynapseError(400, "message") - - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/profile/%s/displayname" % ("@opaque:elsewhere"), - b'{"displayname":"bob"}', - ) - - self.assertTrue(400 <= code <= 499, msg="code %d is in the 4xx range" % (code)) - - @defer.inlineCallbacks - def test_get_my_avatar(self): - mocked_get = self.mock_handler.get_avatar_url - mocked_get.return_value = defer.succeed("http://my.server/me.png") - - (code, response) = yield self.mock_resource.trigger( - "GET", "/profile/%s/avatar_url" % (myid), None - ) - - self.assertEquals(200, code) - self.assertEquals({"avatar_url": "http://my.server/me.png"}, response) - self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD") - - @defer.inlineCallbacks - def test_set_my_avatar(self): - mocked_set = self.mock_handler.set_avatar_url - mocked_set.return_value = defer.succeed(()) - - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/profile/%s/avatar_url" % (myid), - b'{"avatar_url": "http://my.server/pic.gif"}', - ) - - self.assertEquals(200, code) - self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif") - class ProfileTestCase(unittest.HomeserverTestCase): @@ -187,37 +35,122 @@ class ProfileTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.owner = self.register_user("owner", "pass") self.owner_tok = self.login("owner", "pass") + self.other = self.register_user("other", "pass", displayname="Bob") + + def test_get_displayname(self): + res = self._get_displayname() + self.assertEqual(res, "owner") def test_set_displayname(self): channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), - content=json.dumps({"displayname": "test"}), + content={"displayname": "test"}, access_token=self.owner_tok, ) self.assertEqual(channel.code, 200, channel.result) - res = self.get_displayname() + res = self._get_displayname() self.assertEqual(res, "test") + def test_set_displayname_noauth(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner,), + content={"displayname": "test"}, + ) + self.assertEqual(channel.code, 401, channel.result) + def test_set_displayname_too_long(self): """Attempts to set a stupid displayname should get a 400""" channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), - content=json.dumps({"displayname": "test" * 100}), + content={"displayname": "test" * 100}, access_token=self.owner_tok, ) self.assertEqual(channel.code, 400, channel.result) - res = self.get_displayname() + res = self._get_displayname() self.assertEqual(res, "owner") - def get_displayname(self): - channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,)) + def test_get_displayname_other(self): + res = self._get_displayname(self.other) + self.assertEquals(res, "Bob") + + def test_set_displayname_other(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.other,), + content={"displayname": "test"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + def test_get_avatar_url(self): + res = self._get_avatar_url() + self.assertIsNone(res) + + def test_set_avatar_url(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + res = self._get_avatar_url() + self.assertEqual(res, "http://my.server/pic.gif") + + def test_set_avatar_url_noauth(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif"}, + ) + self.assertEqual(channel.code, 401, channel.result) + + def test_set_avatar_url_too_long(self): + """Attempts to set a stupid avatar_url should get a 400""" + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif" * 100}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + res = self._get_avatar_url() + self.assertIsNone(res) + + def test_get_avatar_url_other(self): + res = self._get_avatar_url(self.other) + self.assertIsNone(res) + + def test_set_avatar_url_other(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.other,), + content={"avatar_url": "http://my.server/pic.gif"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + def _get_displayname(self, name=None): + channel = self.make_request( + "GET", "/profile/%s/displayname" % (name or self.owner,) + ) self.assertEqual(channel.code, 200, channel.result) return channel.json_body["displayname"] + def _get_avatar_url(self, name=None): + channel = self.make_request( + "GET", "/profile/%s/avatar_url" % (name or self.owner,) + ) + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body.get("avatar_url") + class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):