forked from MirrorHub/synapse
Prefer make_awaitable
over defer.succeed
in tests (#12505)
When configuring the return values of mocks, prefer awaitables from `make_awaitable` over `defer.succeed`. `Deferred`s are only awaitable once, so it is inappropriate for a mock to return the same `Deferred` multiple times. Also update `run_in_background` to support functions that return arbitrary awaitables. Signed-off-by: Sean Quah <seanq@element.io>
This commit is contained in:
parent
5ef673de4f
commit
78b99de7c2
14 changed files with 72 additions and 69 deletions
1
changelog.d/12505.misc
Normal file
1
changelog.d/12505.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Use `make_awaitable` instead of `defer.succeed` for return values of mocks in tests.
|
|
@ -722,6 +722,11 @@ P = ParamSpec("P")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
|
async def _unwrap_awaitable(awaitable: Awaitable[R]) -> R:
|
||||||
|
"""Unwraps an arbitrary awaitable by awaiting it."""
|
||||||
|
return await awaitable
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def preserve_fn( # type: ignore[misc]
|
def preserve_fn( # type: ignore[misc]
|
||||||
f: Callable[P, Awaitable[R]],
|
f: Callable[P, Awaitable[R]],
|
||||||
|
@ -802,17 +807,20 @@ def run_in_background( # type: ignore[misc]
|
||||||
# by synchronous exceptions, so let's turn them into Failures.
|
# by synchronous exceptions, so let's turn them into Failures.
|
||||||
return defer.fail()
|
return defer.fail()
|
||||||
|
|
||||||
|
# `res` may be a coroutine, `Deferred`, some other kind of awaitable, or a plain
|
||||||
|
# value. Convert it to a `Deferred`.
|
||||||
if isinstance(res, typing.Coroutine):
|
if isinstance(res, typing.Coroutine):
|
||||||
|
# Wrap the coroutine in a `Deferred`.
|
||||||
res = defer.ensureDeferred(res)
|
res = defer.ensureDeferred(res)
|
||||||
|
elif isinstance(res, defer.Deferred):
|
||||||
# At this point we should have a Deferred, if not then f was a synchronous
|
pass
|
||||||
# function, wrap it in a Deferred for consistency.
|
elif isinstance(res, Awaitable):
|
||||||
if not isinstance(res, defer.Deferred):
|
# `res` is probably some kind of completed awaitable, such as a `DoneAwaitable`
|
||||||
# `res` is not a `Deferred` and not a `Coroutine`.
|
# or `Future` from `make_awaitable`.
|
||||||
# There are no other types of `Awaitable`s we expect to encounter in Synapse.
|
res = defer.ensureDeferred(_unwrap_awaitable(res))
|
||||||
assert not isinstance(res, Awaitable)
|
else:
|
||||||
|
# `res` is a plain value. Wrap it in a `Deferred`.
|
||||||
return defer.succeed(res)
|
res = defer.succeed(res)
|
||||||
|
|
||||||
if res.called and not res.paused:
|
if res.called and not res.paused:
|
||||||
# The function should have maintained the logcontext, so we can
|
# The function should have maintained the logcontext, so we can
|
||||||
|
|
|
@ -83,7 +83,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# mock up the response, and have the agent return it
|
# mock up the response, and have the agent return it
|
||||||
self._mock_agent.request.return_value = defer.succeed(
|
self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
|
||||||
_mock_response(
|
_mock_response(
|
||||||
{
|
{
|
||||||
"pdus": [
|
"pdus": [
|
||||||
|
|
|
@ -226,7 +226,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
|
||||||
# Send the server a device list EDU for the other user, this will cause
|
# Send the server a device list EDU for the other user, this will cause
|
||||||
# it to try and resync the device lists.
|
# it to try and resync the device lists.
|
||||||
self.hs.get_federation_transport_client().query_user_devices.return_value = (
|
self.hs.get_federation_transport_client().query_user_devices.return_value = (
|
||||||
defer.succeed(
|
make_awaitable(
|
||||||
{
|
{
|
||||||
"stream_id": "1",
|
"stream_id": "1",
|
||||||
"user_id": "@user2:host2",
|
"user_id": "@user2:host2",
|
||||||
|
|
|
@ -19,7 +19,6 @@ from unittest import mock
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
from signedjson import key as key, sign as sign
|
from signedjson import key as key, sign as sign
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import RoomEncryptionAlgorithms
|
from synapse.api.constants import RoomEncryptionAlgorithms
|
||||||
|
@ -704,7 +703,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
|
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
|
||||||
|
|
||||||
self.hs.get_federation_client().query_client_keys = mock.Mock(
|
self.hs.get_federation_client().query_client_keys = mock.Mock(
|
||||||
return_value=defer.succeed(
|
return_value=make_awaitable(
|
||||||
{
|
{
|
||||||
"device_keys": {remote_user_id: {}},
|
"device_keys": {remote_user_id: {}},
|
||||||
"master_keys": {
|
"master_keys": {
|
||||||
|
@ -777,14 +776,14 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
# Pretend we're sharing a room with the user we're querying. If not,
|
# Pretend we're sharing a room with the user we're querying. If not,
|
||||||
# `_query_devices_for_destination` will return early.
|
# `_query_devices_for_destination` will return early.
|
||||||
self.store.get_rooms_for_user = mock.Mock(
|
self.store.get_rooms_for_user = mock.Mock(
|
||||||
return_value=defer.succeed({"some_room_id"})
|
return_value=make_awaitable({"some_room_id"})
|
||||||
)
|
)
|
||||||
|
|
||||||
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
|
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
|
||||||
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
|
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
|
||||||
|
|
||||||
self.hs.get_federation_client().query_user_devices = mock.Mock(
|
self.hs.get_federation_client().query_user_devices = mock.Mock(
|
||||||
return_value=defer.succeed(
|
return_value=make_awaitable(
|
||||||
{
|
{
|
||||||
"user_id": remote_user_id,
|
"user_id": remote_user_id,
|
||||||
"stream_id": 1,
|
"stream_id": 1,
|
||||||
|
|
|
@ -17,8 +17,6 @@
|
||||||
from typing import Any, Type, Union
|
from typing import Any, Type, Union
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
import synapse
|
import synapse
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import Codes
|
from synapse.api.errors import Codes
|
||||||
|
@ -190,7 +188,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
|
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
|
||||||
|
|
||||||
# check_password must return an awaitable
|
# check_password must return an awaitable
|
||||||
mock_password_provider.check_password.return_value = defer.succeed(True)
|
mock_password_provider.check_password.return_value = make_awaitable(True)
|
||||||
channel = self._send_password_login("u", "p")
|
channel = self._send_password_login("u", "p")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
self.assertEqual("@u:test", channel.json_body["user_id"])
|
self.assertEqual("@u:test", channel.json_body["user_id"])
|
||||||
|
@ -226,13 +224,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self.get_success(module_api.register_user("u"))
|
self.get_success(module_api.register_user("u"))
|
||||||
|
|
||||||
# log in twice, to get two devices
|
# log in twice, to get two devices
|
||||||
mock_password_provider.check_password.return_value = defer.succeed(True)
|
mock_password_provider.check_password.return_value = make_awaitable(True)
|
||||||
tok1 = self.login("u", "p")
|
tok1 = self.login("u", "p")
|
||||||
self.login("u", "p", device_id="dev2")
|
self.login("u", "p", device_id="dev2")
|
||||||
mock_password_provider.reset_mock()
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
# have the auth provider deny the request to start with
|
# have the auth provider deny the request to start with
|
||||||
mock_password_provider.check_password.return_value = defer.succeed(False)
|
mock_password_provider.check_password.return_value = make_awaitable(False)
|
||||||
|
|
||||||
# make the initial request which returns a 401
|
# make the initial request which returns a 401
|
||||||
session = self._start_delete_device_session(tok1, "dev2")
|
session = self._start_delete_device_session(tok1, "dev2")
|
||||||
|
@ -246,7 +244,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
mock_password_provider.reset_mock()
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
# Finally, check the request goes through when we allow it
|
# Finally, check the request goes through when we allow it
|
||||||
mock_password_provider.check_password.return_value = defer.succeed(True)
|
mock_password_provider.check_password.return_value = make_awaitable(True)
|
||||||
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
|
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
||||||
|
@ -260,7 +258,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
# check_password must return an awaitable
|
# check_password must return an awaitable
|
||||||
mock_password_provider.check_password.return_value = defer.succeed(False)
|
mock_password_provider.check_password.return_value = make_awaitable(False)
|
||||||
channel = self._send_password_login("u", "p")
|
channel = self._send_password_login("u", "p")
|
||||||
self.assertEqual(channel.code, 403, channel.result)
|
self.assertEqual(channel.code, 403, channel.result)
|
||||||
|
|
||||||
|
@ -277,7 +275,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
# have the auth provider deny the request
|
# have the auth provider deny the request
|
||||||
mock_password_provider.check_password.return_value = defer.succeed(False)
|
mock_password_provider.check_password.return_value = make_awaitable(False)
|
||||||
|
|
||||||
# log in twice, to get two devices
|
# log in twice, to get two devices
|
||||||
tok1 = self.login("localuser", "localpass")
|
tok1 = self.login("localuser", "localpass")
|
||||||
|
@ -320,7 +318,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
# check_password must return an awaitable
|
# check_password must return an awaitable
|
||||||
mock_password_provider.check_password.return_value = defer.succeed(False)
|
mock_password_provider.check_password.return_value = make_awaitable(False)
|
||||||
channel = self._send_password_login("localuser", "localpass")
|
channel = self._send_password_login("localuser", "localpass")
|
||||||
self.assertEqual(channel.code, 403)
|
self.assertEqual(channel.code, 403)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
|
@ -342,7 +340,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
|
|
||||||
# allow login via the auth provider
|
# allow login via the auth provider
|
||||||
mock_password_provider.check_password.return_value = defer.succeed(True)
|
mock_password_provider.check_password.return_value = make_awaitable(True)
|
||||||
|
|
||||||
# log in twice, to get two devices
|
# log in twice, to get two devices
|
||||||
tok1 = self.login("localuser", "p")
|
tok1 = self.login("localuser", "p")
|
||||||
|
@ -359,7 +357,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
mock_password_provider.check_password.assert_not_called()
|
mock_password_provider.check_password.assert_not_called()
|
||||||
|
|
||||||
# now try deleting with the local password
|
# now try deleting with the local password
|
||||||
mock_password_provider.check_password.return_value = defer.succeed(False)
|
mock_password_provider.check_password.return_value = make_awaitable(False)
|
||||||
channel = self._authed_delete_device(
|
channel = self._authed_delete_device(
|
||||||
tok1, "dev2", session, "localuser", "localpass"
|
tok1, "dev2", session, "localuser", "localpass"
|
||||||
)
|
)
|
||||||
|
@ -413,7 +411,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
mock_password_provider.check_auth.assert_not_called()
|
mock_password_provider.check_auth.assert_not_called()
|
||||||
|
|
||||||
mock_password_provider.check_auth.return_value = defer.succeed(
|
mock_password_provider.check_auth.return_value = make_awaitable(
|
||||||
("@user:bz", None)
|
("@user:bz", None)
|
||||||
)
|
)
|
||||||
channel = self._send_login("test.login_type", "u", test_field="y")
|
channel = self._send_login("test.login_type", "u", test_field="y")
|
||||||
|
@ -427,7 +425,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
# try a weird username. Again, it's unclear what we *expect* to happen
|
# try a weird username. Again, it's unclear what we *expect* to happen
|
||||||
# in these cases, but at least we can guard against the API changing
|
# in these cases, but at least we can guard against the API changing
|
||||||
# unexpectedly
|
# unexpectedly
|
||||||
mock_password_provider.check_auth.return_value = defer.succeed(
|
mock_password_provider.check_auth.return_value = make_awaitable(
|
||||||
("@ MALFORMED! :bz", None)
|
("@ MALFORMED! :bz", None)
|
||||||
)
|
)
|
||||||
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
|
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
|
||||||
|
@ -477,7 +475,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
mock_password_provider.reset_mock()
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
# right params, but authing as the wrong user
|
# right params, but authing as the wrong user
|
||||||
mock_password_provider.check_auth.return_value = defer.succeed(
|
mock_password_provider.check_auth.return_value = make_awaitable(
|
||||||
("@user:bz", None)
|
("@user:bz", None)
|
||||||
)
|
)
|
||||||
body["auth"]["test_field"] = "foo"
|
body["auth"]["test_field"] = "foo"
|
||||||
|
@ -490,7 +488,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
mock_password_provider.reset_mock()
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
# and finally, succeed
|
# and finally, succeed
|
||||||
mock_password_provider.check_auth.return_value = defer.succeed(
|
mock_password_provider.check_auth.return_value = make_awaitable(
|
||||||
("@localuser:test", None)
|
("@localuser:test", None)
|
||||||
)
|
)
|
||||||
channel = self._delete_device(tok1, "dev2", body)
|
channel = self._delete_device(tok1, "dev2", body)
|
||||||
|
@ -508,9 +506,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
self.custom_auth_provider_callback_test_body()
|
self.custom_auth_provider_callback_test_body()
|
||||||
|
|
||||||
def custom_auth_provider_callback_test_body(self):
|
def custom_auth_provider_callback_test_body(self):
|
||||||
callback = Mock(return_value=defer.succeed(None))
|
callback = Mock(return_value=make_awaitable(None))
|
||||||
|
|
||||||
mock_password_provider.check_auth.return_value = defer.succeed(
|
mock_password_provider.check_auth.return_value = make_awaitable(
|
||||||
("@user:bz", callback)
|
("@user:bz", callback)
|
||||||
)
|
)
|
||||||
channel = self._send_login("test.login_type", "u", test_field="y")
|
channel = self._send_login("test.login_type", "u", test_field="y")
|
||||||
|
@ -646,7 +644,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
login is disabled"""
|
login is disabled"""
|
||||||
# register the user and log in twice via the test login type to get two devices,
|
# register the user and log in twice via the test login type to get two devices,
|
||||||
self.register_user("localuser", "localpass")
|
self.register_user("localuser", "localpass")
|
||||||
mock_password_provider.check_auth.return_value = defer.succeed(
|
mock_password_provider.check_auth.return_value = make_awaitable(
|
||||||
("@localuser:test", None)
|
("@localuser:test", None)
|
||||||
)
|
)
|
||||||
channel = self._send_login("test.login_type", "localuser", test_field="")
|
channel = self._send_login("test.login_type", "localuser", test_field="")
|
||||||
|
|
|
@ -65,11 +65,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
# we mock out the keyring so as to skip the authentication check on the
|
# we mock out the keyring so as to skip the authentication check on the
|
||||||
# federation API call.
|
# federation API call.
|
||||||
mock_keyring = Mock(spec=["verify_json_for_server"])
|
mock_keyring = Mock(spec=["verify_json_for_server"])
|
||||||
mock_keyring.verify_json_for_server.return_value = defer.succeed(True)
|
mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
|
||||||
|
|
||||||
# we mock out the federation client too
|
# we mock out the federation client too
|
||||||
mock_federation_client = Mock(spec=["put_json"])
|
mock_federation_client = Mock(spec=["put_json"])
|
||||||
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
|
mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
|
||||||
|
|
||||||
# the tests assume that we are starting at unix time 1000
|
# the tests assume that we are starting at unix time 1000
|
||||||
reactor.pump((1000,))
|
reactor.pump((1000,))
|
||||||
|
@ -98,7 +98,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.datastore = hs.get_datastores().main
|
self.datastore = hs.get_datastores().main
|
||||||
self.datastore.get_destination_retry_timings = Mock(
|
self.datastore.get_destination_retry_timings = Mock(
|
||||||
return_value=defer.succeed(None)
|
return_value=make_awaitable(None)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.datastore.get_device_updates_by_remote = Mock(
|
self.datastore.get_device_updates_by_remote = Mock(
|
||||||
|
|
|
@ -15,7 +15,6 @@ from typing import Tuple
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
|
@ -30,6 +29,7 @@ from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.storage.test_user_directory import GetUserDirectoryTables
|
from tests.storage.test_user_directory import GetUserDirectoryTables
|
||||||
|
from tests.test_utils import make_awaitable
|
||||||
from tests.test_utils.event_injection import inject_member_event
|
from tests.test_utils.event_injection import inject_member_event
|
||||||
from tests.unittest import override_config
|
from tests.unittest import override_config
|
||||||
|
|
||||||
|
@ -439,7 +439,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
|
mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
|
||||||
with patch.object(
|
with patch.object(
|
||||||
self.store, "remove_from_user_dir", mock_remove_from_user_dir
|
self.store, "remove_from_user_dir", mock_remove_from_user_dir
|
||||||
):
|
):
|
||||||
|
@ -454,7 +454,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.register_user(user_id=r_user_id, password_hash=None)
|
self.store.register_user(user_id=r_user_id, password_hash=None)
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
|
mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
|
||||||
with patch.object(
|
with patch.object(
|
||||||
self.store, "remove_from_user_dir", mock_remove_from_user_dir
|
self.store, "remove_from_user_dir", mock_remove_from_user_dir
|
||||||
):
|
):
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.handlers.presence import PresenceHandler
|
from synapse.handlers.presence import PresenceHandler
|
||||||
|
@ -24,6 +23,7 @@ from synapse.types import UserID
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.test_utils import make_awaitable
|
||||||
|
|
||||||
|
|
||||||
class PresenceTestCase(unittest.HomeserverTestCase):
|
class PresenceTestCase(unittest.HomeserverTestCase):
|
||||||
|
@ -37,7 +37,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
|
||||||
presence_handler = Mock(spec=PresenceHandler)
|
presence_handler = Mock(spec=PresenceHandler)
|
||||||
presence_handler.set_state.return_value = defer.succeed(None)
|
presence_handler.set_state.return_value = make_awaitable(None)
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
"red",
|
"red",
|
||||||
|
|
|
@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, List, Optional
|
||||||
from unittest.mock import Mock, call
|
from unittest.mock import Mock, call
|
||||||
from urllib import parse as urlparse
|
from urllib import parse as urlparse
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
|
@ -1426,9 +1425,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_simple(self) -> None:
|
def test_simple(self) -> None:
|
||||||
"Simple test for searching rooms over federation"
|
"Simple test for searching rooms over federation"
|
||||||
self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined]
|
self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined]
|
||||||
{}
|
|
||||||
)
|
|
||||||
|
|
||||||
search_filter = {"generic_search_term": "foobar"}
|
search_filter = {"generic_search_term": "foobar"}
|
||||||
|
|
||||||
|
@ -1456,7 +1453,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
|
||||||
# with a 404, when using search filters.
|
# with a 404, when using search filters.
|
||||||
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
|
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
|
||||||
HttpResponseException(404, "Not Found", b""),
|
HttpResponseException(404, "Not Found", b""),
|
||||||
defer.succeed({}),
|
make_awaitable({}),
|
||||||
)
|
)
|
||||||
|
|
||||||
search_filter = {"generic_search_term": "foobar"}
|
search_filter = {"generic_search_term": "foobar"}
|
||||||
|
|
|
@ -22,6 +22,7 @@ from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionC
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.test_utils import make_awaitable
|
||||||
from tests.utils import MockClock
|
from tests.utils import MockClock
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,7 +39,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_executes_given_function(self):
|
def test_executes_given_function(self):
|
||||||
cb = Mock(return_value=defer.succeed(self.mock_http_response))
|
cb = Mock(return_value=make_awaitable(self.mock_http_response))
|
||||||
res = yield self.cache.fetch_or_execute(
|
res = yield self.cache.fetch_or_execute(
|
||||||
self.mock_key, cb, "some_arg", keyword="arg"
|
self.mock_key, cb, "some_arg", keyword="arg"
|
||||||
)
|
)
|
||||||
|
@ -47,7 +48,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_deduplicates_based_on_key(self):
|
def test_deduplicates_based_on_key(self):
|
||||||
cb = Mock(return_value=defer.succeed(self.mock_http_response))
|
cb = Mock(return_value=make_awaitable(self.mock_http_response))
|
||||||
for i in range(3): # invoke multiple times
|
for i in range(3): # invoke multiple times
|
||||||
res = yield self.cache.fetch_or_execute(
|
res = yield self.cache.fetch_or_execute(
|
||||||
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
|
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
|
||||||
|
@ -130,7 +131,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_cleans_up(self):
|
def test_cleans_up(self):
|
||||||
cb = Mock(return_value=defer.succeed(self.mock_http_response))
|
cb = Mock(return_value=make_awaitable(self.mock_http_response))
|
||||||
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
|
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
|
||||||
# should NOT have cleaned up yet
|
# should NOT have cleaned up yet
|
||||||
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
|
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
|
||||||
|
|
|
@ -14,8 +14,6 @@
|
||||||
|
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType
|
from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType
|
||||||
from synapse.api.errors import ResourceLimitError
|
from synapse.api.errors import ResourceLimitError
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
|
@ -68,16 +66,16 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
return_value=make_awaitable(1000)
|
return_value=make_awaitable(1000)
|
||||||
)
|
)
|
||||||
self._rlsn._server_notices_manager.send_notice = Mock(
|
self._rlsn._server_notices_manager.send_notice = Mock(
|
||||||
return_value=defer.succeed(Mock())
|
return_value=make_awaitable(Mock())
|
||||||
)
|
)
|
||||||
self._send_notice = self._rlsn._server_notices_manager.send_notice
|
self._send_notice = self._rlsn._server_notices_manager.send_notice
|
||||||
|
|
||||||
self.user_id = "@user_id:test"
|
self.user_id = "@user_id:test"
|
||||||
|
|
||||||
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
|
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
|
||||||
return_value=defer.succeed("!something:localhost")
|
return_value=make_awaitable("!something:localhost")
|
||||||
)
|
)
|
||||||
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
|
self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None))
|
||||||
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
|
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
|
||||||
|
|
||||||
@override_config({"hs_disabled": True})
|
@override_config({"hs_disabled": True})
|
||||||
|
@ -95,7 +93,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
|
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
|
||||||
"""Test when user has blocked notice, but should have it removed"""
|
"""Test when user has blocked notice, but should have it removed"""
|
||||||
|
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
|
self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
|
||||||
mock_event = Mock(
|
mock_event = Mock(
|
||||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
||||||
)
|
)
|
||||||
|
@ -111,7 +109,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
Test when user has blocked notice, but notice ought to be there (NOOP)
|
Test when user has blocked notice, but notice ought to be there (NOOP)
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(
|
self._rlsn._auth.check_auth_blocking = Mock(
|
||||||
return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
|
return_value=make_awaitable(None),
|
||||||
|
side_effect=ResourceLimitError(403, "foo"),
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_event = Mock(
|
mock_event = Mock(
|
||||||
|
@ -130,7 +129,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
Test when user does not have blocked notice, but should have one
|
Test when user does not have blocked notice, but should have one
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(
|
self._rlsn._auth.check_auth_blocking = Mock(
|
||||||
return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
|
return_value=make_awaitable(None),
|
||||||
|
side_effect=ResourceLimitError(403, "foo"),
|
||||||
)
|
)
|
||||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||||
|
|
||||||
|
@ -141,7 +141,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
Test when user does not have blocked notice, nor should they (NOOP)
|
Test when user does not have blocked notice, nor should they (NOOP)
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
|
self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
|
||||||
|
|
||||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||||
|
|
||||||
|
@ -152,7 +152,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
Test when user is not part of the MAU cohort - this should not ever
|
Test when user is not part of the MAU cohort - this should not ever
|
||||||
happen - but ...
|
happen - but ...
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
|
self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
|
||||||
self._rlsn._store.user_last_seen_monthly_active = Mock(
|
self._rlsn._store.user_last_seen_monthly_active = Mock(
|
||||||
return_value=make_awaitable(None)
|
return_value=make_awaitable(None)
|
||||||
)
|
)
|
||||||
|
@ -167,7 +167,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
an alert message is not sent into the room
|
an alert message is not sent into the room
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(
|
self._rlsn._auth.check_auth_blocking = Mock(
|
||||||
return_value=defer.succeed(None),
|
return_value=make_awaitable(None),
|
||||||
side_effect=ResourceLimitError(
|
side_effect=ResourceLimitError(
|
||||||
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
||||||
),
|
),
|
||||||
|
@ -182,7 +182,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
Test that when a server is disabled, that MAU limit alerting is ignored.
|
Test that when a server is disabled, that MAU limit alerting is ignored.
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(
|
self._rlsn._auth.check_auth_blocking = Mock(
|
||||||
return_value=defer.succeed(None),
|
return_value=make_awaitable(None),
|
||||||
side_effect=ResourceLimitError(
|
side_effect=ResourceLimitError(
|
||||||
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
|
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
|
||||||
),
|
),
|
||||||
|
@ -199,14 +199,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
is suppressed that the room is returned to an unblocked state.
|
is suppressed that the room is returned to an unblocked state.
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(
|
self._rlsn._auth.check_auth_blocking = Mock(
|
||||||
return_value=defer.succeed(None),
|
return_value=make_awaitable(None),
|
||||||
side_effect=ResourceLimitError(
|
side_effect=ResourceLimitError(
|
||||||
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
|
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
|
||||||
return_value=defer.succeed((True, []))
|
return_value=make_awaitable((True, []))
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_event = Mock(
|
mock_event = Mock(
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
|
@ -259,10 +258,10 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
def test_populate_monthly_users_should_update(self):
|
def test_populate_monthly_users_should_update(self):
|
||||||
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
|
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
|
||||||
|
|
||||||
self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment]
|
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
|
||||||
|
|
||||||
self.store.user_last_seen_monthly_active = Mock(
|
self.store.user_last_seen_monthly_active = Mock(
|
||||||
return_value=defer.succeed(None)
|
return_value=make_awaitable(None)
|
||||||
)
|
)
|
||||||
d = self.store.populate_monthly_active_users("user_id")
|
d = self.store.populate_monthly_active_users("user_id")
|
||||||
self.get_success(d)
|
self.get_success(d)
|
||||||
|
@ -272,9 +271,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
def test_populate_monthly_users_should_not_update(self):
|
def test_populate_monthly_users_should_not_update(self):
|
||||||
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
|
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
|
||||||
|
|
||||||
self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment]
|
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
|
||||||
self.store.user_last_seen_monthly_active = Mock(
|
self.store.user_last_seen_monthly_active = Mock(
|
||||||
return_value=defer.succeed(self.hs.get_clock().time_msec())
|
return_value=make_awaitable(self.hs.get_clock().time_msec())
|
||||||
)
|
)
|
||||||
|
|
||||||
d = self.store.populate_monthly_active_users("user_id")
|
d = self.store.populate_monthly_active_users("user_id")
|
||||||
|
|
|
@ -233,7 +233,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||||
# Register mock device list retrieval on the federation client.
|
# Register mock device list retrieval on the federation client.
|
||||||
federation_client = self.homeserver.get_federation_client()
|
federation_client = self.homeserver.get_federation_client()
|
||||||
federation_client.query_user_devices = Mock(
|
federation_client.query_user_devices = Mock(
|
||||||
return_value=succeed(
|
return_value=make_awaitable(
|
||||||
{
|
{
|
||||||
"user_id": remote_user_id,
|
"user_id": remote_user_id,
|
||||||
"stream_id": 1,
|
"stream_id": 1,
|
||||||
|
|
Loading…
Reference in a new issue