0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-18 05:42:03 +01:00

Prefer make_awaitable over defer.succeed in tests (#12505)

When configuring the return values of mocks, prefer awaitables from
`make_awaitable` over `defer.succeed`. `Deferred`s are only awaitable
once, so it is inappropriate for a mock to return the same `Deferred`
multiple times.

Also update `run_in_background` to support functions that return
arbitrary awaitables.

Signed-off-by: Sean Quah <seanq@element.io>
This commit is contained in:
Sean Quah 2022-04-27 14:58:26 +01:00 committed by GitHub
parent 5ef673de4f
commit 78b99de7c2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 72 additions and 69 deletions

1
changelog.d/12505.misc Normal file
View file

@ -0,0 +1 @@
Use `make_awaitable` instead of `defer.succeed` for return values of mocks in tests.

View file

@ -722,6 +722,11 @@ P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
async def _unwrap_awaitable(awaitable: Awaitable[R]) -> R:
"""Unwraps an arbitrary awaitable by awaiting it."""
return await awaitable
@overload @overload
def preserve_fn( # type: ignore[misc] def preserve_fn( # type: ignore[misc]
f: Callable[P, Awaitable[R]], f: Callable[P, Awaitable[R]],
@ -802,17 +807,20 @@ def run_in_background( # type: ignore[misc]
# by synchronous exceptions, so let's turn them into Failures. # by synchronous exceptions, so let's turn them into Failures.
return defer.fail() return defer.fail()
# `res` may be a coroutine, `Deferred`, some other kind of awaitable, or a plain
# value. Convert it to a `Deferred`.
if isinstance(res, typing.Coroutine): if isinstance(res, typing.Coroutine):
# Wrap the coroutine in a `Deferred`.
res = defer.ensureDeferred(res) res = defer.ensureDeferred(res)
elif isinstance(res, defer.Deferred):
# At this point we should have a Deferred, if not then f was a synchronous pass
# function, wrap it in a Deferred for consistency. elif isinstance(res, Awaitable):
if not isinstance(res, defer.Deferred): # `res` is probably some kind of completed awaitable, such as a `DoneAwaitable`
# `res` is not a `Deferred` and not a `Coroutine`. # or `Future` from `make_awaitable`.
# There are no other types of `Awaitable`s we expect to encounter in Synapse. res = defer.ensureDeferred(_unwrap_awaitable(res))
assert not isinstance(res, Awaitable) else:
# `res` is a plain value. Wrap it in a `Deferred`.
return defer.succeed(res) res = defer.succeed(res)
if res.called and not res.paused: if res.called and not res.paused:
# The function should have maintained the logcontext, so we can # The function should have maintained the logcontext, so we can

View file

@ -83,7 +83,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
) )
# mock up the response, and have the agent return it # mock up the response, and have the agent return it
self._mock_agent.request.return_value = defer.succeed( self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
_mock_response( _mock_response(
{ {
"pdus": [ "pdus": [

View file

@ -226,7 +226,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# Send the server a device list EDU for the other user, this will cause # Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists. # it to try and resync the device lists.
self.hs.get_federation_transport_client().query_user_devices.return_value = ( self.hs.get_federation_transport_client().query_user_devices.return_value = (
defer.succeed( make_awaitable(
{ {
"stream_id": "1", "stream_id": "1",
"user_id": "@user2:host2", "user_id": "@user2:host2",

View file

@ -19,7 +19,6 @@ from unittest import mock
from parameterized import parameterized from parameterized import parameterized
from signedjson import key as key, sign as sign from signedjson import key as key, sign as sign
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.constants import RoomEncryptionAlgorithms
@ -704,7 +703,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
self.hs.get_federation_client().query_client_keys = mock.Mock( self.hs.get_federation_client().query_client_keys = mock.Mock(
return_value=defer.succeed( return_value=make_awaitable(
{ {
"device_keys": {remote_user_id: {}}, "device_keys": {remote_user_id: {}},
"master_keys": { "master_keys": {
@ -777,14 +776,14 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Pretend we're sharing a room with the user we're querying. If not, # Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early. # `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock( self.store.get_rooms_for_user = mock.Mock(
return_value=defer.succeed({"some_room_id"}) return_value=make_awaitable({"some_room_id"})
) )
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
self.hs.get_federation_client().query_user_devices = mock.Mock( self.hs.get_federation_client().query_user_devices = mock.Mock(
return_value=defer.succeed( return_value=make_awaitable(
{ {
"user_id": remote_user_id, "user_id": remote_user_id,
"stream_id": 1, "stream_id": 1,

View file

@ -17,8 +17,6 @@
from typing import Any, Type, Union from typing import Any, Type, Union
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet import defer
import synapse import synapse
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes from synapse.api.errors import Codes
@ -190,7 +188,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
# check_password must return an awaitable # check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(True) mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._send_password_login("u", "p") channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"]) self.assertEqual("@u:test", channel.json_body["user_id"])
@ -226,13 +224,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.get_success(module_api.register_user("u")) self.get_success(module_api.register_user("u"))
# log in twice, to get two devices # log in twice, to get two devices
mock_password_provider.check_password.return_value = defer.succeed(True) mock_password_provider.check_password.return_value = make_awaitable(True)
tok1 = self.login("u", "p") tok1 = self.login("u", "p")
self.login("u", "p", device_id="dev2") self.login("u", "p", device_id="dev2")
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
# have the auth provider deny the request to start with # have the auth provider deny the request to start with
mock_password_provider.check_password.return_value = defer.succeed(False) mock_password_provider.check_password.return_value = make_awaitable(False)
# make the initial request which returns a 401 # make the initial request which returns a 401
session = self._start_delete_device_session(tok1, "dev2") session = self._start_delete_device_session(tok1, "dev2")
@ -246,7 +244,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
# Finally, check the request goes through when we allow it # Finally, check the request goes through when we allow it
mock_password_provider.check_password.return_value = defer.succeed(True) mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p") mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@ -260,7 +258,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
# check_password must return an awaitable # check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(False) mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("u", "p") channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 403, channel.result) self.assertEqual(channel.code, 403, channel.result)
@ -277,7 +275,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
# have the auth provider deny the request # have the auth provider deny the request
mock_password_provider.check_password.return_value = defer.succeed(False) mock_password_provider.check_password.return_value = make_awaitable(False)
# log in twice, to get two devices # log in twice, to get two devices
tok1 = self.login("localuser", "localpass") tok1 = self.login("localuser", "localpass")
@ -320,7 +318,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
# check_password must return an awaitable # check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(False) mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("localuser", "localpass") channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 403) self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@ -342,7 +340,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
# allow login via the auth provider # allow login via the auth provider
mock_password_provider.check_password.return_value = defer.succeed(True) mock_password_provider.check_password.return_value = make_awaitable(True)
# log in twice, to get two devices # log in twice, to get two devices
tok1 = self.login("localuser", "p") tok1 = self.login("localuser", "p")
@ -359,7 +357,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_not_called() mock_password_provider.check_password.assert_not_called()
# now try deleting with the local password # now try deleting with the local password
mock_password_provider.check_password.return_value = defer.succeed(False) mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._authed_delete_device( channel = self._authed_delete_device(
tok1, "dev2", session, "localuser", "localpass" tok1, "dev2", session, "localuser", "localpass"
) )
@ -413,7 +411,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_auth.return_value = defer.succeed( mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None) ("@user:bz", None)
) )
channel = self._send_login("test.login_type", "u", test_field="y") channel = self._send_login("test.login_type", "u", test_field="y")
@ -427,7 +425,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# try a weird username. Again, it's unclear what we *expect* to happen # try a weird username. Again, it's unclear what we *expect* to happen
# in these cases, but at least we can guard against the API changing # in these cases, but at least we can guard against the API changing
# unexpectedly # unexpectedly
mock_password_provider.check_auth.return_value = defer.succeed( mock_password_provider.check_auth.return_value = make_awaitable(
("@ MALFORMED! :bz", None) ("@ MALFORMED! :bz", None)
) )
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
@ -477,7 +475,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
# right params, but authing as the wrong user # right params, but authing as the wrong user
mock_password_provider.check_auth.return_value = defer.succeed( mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None) ("@user:bz", None)
) )
body["auth"]["test_field"] = "foo" body["auth"]["test_field"] = "foo"
@ -490,7 +488,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
# and finally, succeed # and finally, succeed
mock_password_provider.check_auth.return_value = defer.succeed( mock_password_provider.check_auth.return_value = make_awaitable(
("@localuser:test", None) ("@localuser:test", None)
) )
channel = self._delete_device(tok1, "dev2", body) channel = self._delete_device(tok1, "dev2", body)
@ -508,9 +506,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.custom_auth_provider_callback_test_body() self.custom_auth_provider_callback_test_body()
def custom_auth_provider_callback_test_body(self): def custom_auth_provider_callback_test_body(self):
callback = Mock(return_value=defer.succeed(None)) callback = Mock(return_value=make_awaitable(None))
mock_password_provider.check_auth.return_value = defer.succeed( mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", callback) ("@user:bz", callback)
) )
channel = self._send_login("test.login_type", "u", test_field="y") channel = self._send_login("test.login_type", "u", test_field="y")
@ -646,7 +644,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
login is disabled""" login is disabled"""
# register the user and log in twice via the test login type to get two devices, # register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
mock_password_provider.check_auth.return_value = defer.succeed( mock_password_provider.check_auth.return_value = make_awaitable(
("@localuser:test", None) ("@localuser:test", None)
) )
channel = self._send_login("test.login_type", "localuser", test_field="") channel = self._send_login("test.login_type", "localuser", test_field="")

View file

@ -65,11 +65,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
# we mock out the keyring so as to skip the authentication check on the # we mock out the keyring so as to skip the authentication check on the
# federation API call. # federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"]) mock_keyring = Mock(spec=["verify_json_for_server"])
mock_keyring.verify_json_for_server.return_value = defer.succeed(True) mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
# we mock out the federation client too # we mock out the federation client too
mock_federation_client = Mock(spec=["put_json"]) mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
# the tests assume that we are starting at unix time 1000 # the tests assume that we are starting at unix time 1000
reactor.pump((1000,)) reactor.pump((1000,))
@ -98,7 +98,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore = hs.get_datastores().main self.datastore = hs.get_datastores().main
self.datastore.get_destination_retry_timings = Mock( self.datastore.get_destination_retry_timings = Mock(
return_value=defer.succeed(None) return_value=make_awaitable(None)
) )
self.datastore.get_device_updates_by_remote = Mock( self.datastore.get_device_updates_by_remote = Mock(

View file

@ -15,7 +15,6 @@ from typing import Tuple
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from urllib.parse import quote from urllib.parse import quote
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
@ -30,6 +29,7 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.storage.test_user_directory import GetUserDirectoryTables from tests.storage.test_user_directory import GetUserDirectoryTables
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_member_event from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config from tests.unittest import override_config
@ -439,7 +439,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
) )
) )
mock_remove_from_user_dir = Mock(return_value=defer.succeed(None)) mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
with patch.object( with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir self.store, "remove_from_user_dir", mock_remove_from_user_dir
): ):
@ -454,7 +454,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.store.register_user(user_id=r_user_id, password_hash=None) self.store.register_user(user_id=r_user_id, password_hash=None)
) )
mock_remove_from_user_dir = Mock(return_value=defer.succeed(None)) mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
with patch.object( with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir self.store, "remove_from_user_dir", mock_remove_from_user_dir
): ):

View file

@ -14,7 +14,6 @@
from http import HTTPStatus from http import HTTPStatus
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.presence import PresenceHandler from synapse.handlers.presence import PresenceHandler
@ -24,6 +23,7 @@ from synapse.types import UserID
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
class PresenceTestCase(unittest.HomeserverTestCase): class PresenceTestCase(unittest.HomeserverTestCase):
@ -37,7 +37,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
presence_handler = Mock(spec=PresenceHandler) presence_handler = Mock(spec=PresenceHandler)
presence_handler.set_state.return_value = defer.succeed(None) presence_handler.set_state.return_value = make_awaitable(None)
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
"red", "red",

View file

@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, List, Optional
from unittest.mock import Mock, call from unittest.mock import Mock, call
from urllib import parse as urlparse from urllib import parse as urlparse
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
@ -1426,9 +1425,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
def test_simple(self) -> None: def test_simple(self) -> None:
"Simple test for searching rooms over federation" "Simple test for searching rooms over federation"
self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined] self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined]
{}
)
search_filter = {"generic_search_term": "foobar"} search_filter = {"generic_search_term": "foobar"}
@ -1456,7 +1453,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
# with a 404, when using search filters. # with a 404, when using search filters.
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
HttpResponseException(404, "Not Found", b""), HttpResponseException(404, "Not Found", b""),
defer.succeed({}), make_awaitable({}),
) )
search_filter = {"generic_search_term": "foobar"} search_filter = {"generic_search_term": "foobar"}

View file

@ -22,6 +22,7 @@ from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionC
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
from tests.utils import MockClock from tests.utils import MockClock
@ -38,7 +39,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_executes_given_function(self): def test_executes_given_function(self):
cb = Mock(return_value=defer.succeed(self.mock_http_response)) cb = Mock(return_value=make_awaitable(self.mock_http_response))
res = yield self.cache.fetch_or_execute( res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg" self.mock_key, cb, "some_arg", keyword="arg"
) )
@ -47,7 +48,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_deduplicates_based_on_key(self): def test_deduplicates_based_on_key(self):
cb = Mock(return_value=defer.succeed(self.mock_http_response)) cb = Mock(return_value=make_awaitable(self.mock_http_response))
for i in range(3): # invoke multiple times for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute( res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
@ -130,7 +131,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cleans_up(self): def test_cleans_up(self):
cb = Mock(return_value=defer.succeed(self.mock_http_response)) cb = Mock(return_value=make_awaitable(self.mock_http_response))
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
# should NOT have cleaned up yet # should NOT have cleaned up yet
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2) self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)

View file

@ -14,8 +14,6 @@
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet import defer
from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType
from synapse.api.errors import ResourceLimitError from synapse.api.errors import ResourceLimitError
from synapse.rest import admin from synapse.rest import admin
@ -68,16 +66,16 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return_value=make_awaitable(1000) return_value=make_awaitable(1000)
) )
self._rlsn._server_notices_manager.send_notice = Mock( self._rlsn._server_notices_manager.send_notice = Mock(
return_value=defer.succeed(Mock()) return_value=make_awaitable(Mock())
) )
self._send_notice = self._rlsn._server_notices_manager.send_notice self._send_notice = self._rlsn._server_notices_manager.send_notice
self.user_id = "@user_id:test" self.user_id = "@user_id:test"
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
return_value=defer.succeed("!something:localhost") return_value=make_awaitable("!something:localhost")
) )
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None))
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
@override_config({"hs_disabled": True}) @override_config({"hs_disabled": True})
@ -95,7 +93,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
"""Test when user has blocked notice, but should have it removed""" """Test when user has blocked notice, but should have it removed"""
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
mock_event = Mock( mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
) )
@ -111,7 +109,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user has blocked notice, but notice ought to be there (NOOP) Test when user has blocked notice, but notice ought to be there (NOOP)
""" """
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") return_value=make_awaitable(None),
side_effect=ResourceLimitError(403, "foo"),
) )
mock_event = Mock( mock_event = Mock(
@ -130,7 +129,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user does not have blocked notice, but should have one Test when user does not have blocked notice, but should have one
""" """
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") return_value=make_awaitable(None),
side_effect=ResourceLimitError(403, "foo"),
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -141,7 +141,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
""" """
Test when user does not have blocked notice, nor should they (NOOP) Test when user does not have blocked notice, nor should they (NOOP)
""" """
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -152,7 +152,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user is not part of the MAU cohort - this should not ever Test when user is not part of the MAU cohort - this should not ever
happen - but ... happen - but ...
""" """
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
self._rlsn._store.user_last_seen_monthly_active = Mock( self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(None) return_value=make_awaitable(None)
) )
@ -167,7 +167,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
an alert message is not sent into the room an alert message is not sent into the room
""" """
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None), return_value=make_awaitable(None),
side_effect=ResourceLimitError( side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
), ),
@ -182,7 +182,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test that when a server is disabled, that MAU limit alerting is ignored. Test that when a server is disabled, that MAU limit alerting is ignored.
""" """
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None), return_value=make_awaitable(None),
side_effect=ResourceLimitError( side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
), ),
@ -199,14 +199,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
is suppressed that the room is returned to an unblocked state. is suppressed that the room is returned to an unblocked state.
""" """
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None), return_value=make_awaitable(None),
side_effect=ResourceLimitError( side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
), ),
) )
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
return_value=defer.succeed((True, [])) return_value=make_awaitable((True, []))
) )
mock_event = Mock( mock_event = Mock(

View file

@ -14,7 +14,6 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
@ -259,10 +258,10 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def test_populate_monthly_users_should_update(self): def test_populate_monthly_users_should_update(self):
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment] self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
self.store.user_last_seen_monthly_active = Mock( self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None) return_value=make_awaitable(None)
) )
d = self.store.populate_monthly_active_users("user_id") d = self.store.populate_monthly_active_users("user_id")
self.get_success(d) self.get_success(d)
@ -272,9 +271,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def test_populate_monthly_users_should_not_update(self): def test_populate_monthly_users_should_not_update(self):
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment] self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
self.store.user_last_seen_monthly_active = Mock( self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec()) return_value=make_awaitable(self.hs.get_clock().time_msec())
) )
d = self.store.populate_monthly_active_users("user_id") d = self.store.populate_monthly_active_users("user_id")

View file

@ -233,7 +233,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register mock device list retrieval on the federation client. # Register mock device list retrieval on the federation client.
federation_client = self.homeserver.get_federation_client() federation_client = self.homeserver.get_federation_client()
federation_client.query_user_devices = Mock( federation_client.query_user_devices = Mock(
return_value=succeed( return_value=make_awaitable(
{ {
"user_id": remote_user_id, "user_id": remote_user_id,
"stream_id": 1, "stream_id": 1,