0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-15 17:18:19 +02:00

Replace make_awaitable with AsyncMock (#16179)

Python 3.8 provides a native AsyncMock, we can replace the
homegrown version we have.
This commit is contained in:
Patrick Cloke 2023-08-24 19:38:46 -04:00 committed by GitHub
parent 5856a8ba42
commit daf11e26ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
36 changed files with 508 additions and 604 deletions

1
changelog.d/16179.misc Normal file
View file

@ -0,0 +1 @@
Use `AsyncMock` instead of custom code.

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import time import time
from typing import Any, Dict, List, Optional, cast from typing import Any, Dict, List, Optional, cast
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
import attr import attr
import canonicaljson import canonicaljson
@ -45,7 +45,6 @@ from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import logcontext_clean, override_config from tests.unittest import logcontext_clean, override_config
@ -291,7 +290,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
with a null `ts_valid_until_ms` with a null `ts_valid_until_ms`
""" """
mock_fetcher = Mock() mock_fetcher = Mock()
mock_fetcher.get_keys = Mock(return_value=make_awaitable({})) mock_fetcher.get_keys = AsyncMock(return_value={})
key1 = signedjson.key.generate_signing_key("1") key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_signature_keys( r = self.hs.get_datastores().main.store_server_signature_keys(

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from unittest.mock import Mock from unittest.mock import AsyncMock
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.rest import admin from synapse.rest import admin
@ -20,7 +20,6 @@ from synapse.rest.client import login, room
from synapse.types import JsonDict, UserID, create_requester from synapse.types import JsonDict, UserID, create_requester
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
class RoomComplexityTests(unittest.FederatingHomeserverTestCase): class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
@ -75,9 +74,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(("", 1)) return_value=("", 1)
) )
d = handler._remote_join( d = handler._remote_join(
@ -106,9 +105,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(("", 1)) return_value=("", 1)
) )
d = handler._remote_join( d = handler._remote_join(
@ -143,9 +142,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] fed_transport.client.get_json = AsyncMock(return_value=None) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(("", 1)) return_value=("", 1)
) )
# Artificially raise the complexity # Artificially raise the complexity
@ -200,9 +199,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(("", 1)) return_value=("", 1)
) )
d = handler._remote_join( d = handler._remote_join(
@ -230,9 +229,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(("", 1)) return_value=("", 1)
) )
d = handler._remote_join( d = handler._remote_join(

View file

@ -1,6 +1,6 @@
from typing import Callable, Collection, List, Optional, Tuple from typing import Callable, Collection, List, Optional, Tuple
from unittest import mock from unittest import mock
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -19,7 +19,7 @@ from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from tests.test_utils import event_injection, make_awaitable from tests.test_utils import event_injection
from tests.unittest import FederatingHomeserverTestCase from tests.unittest import FederatingHomeserverTestCase
@ -50,8 +50,8 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# This mock is crucial for destination_rooms to be populated. # This mock is crucial for destination_rooms to be populated.
# TODO: this seems to no longer be the case---tests pass with this mock # TODO: this seems to no longer be the case---tests pass with this mock
# commented out. # commented out.
state_storage_controller.get_current_hosts_in_room = Mock( # type: ignore[assignment] state_storage_controller.get_current_hosts_in_room = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable({"test", "host2"}) return_value={"test", "host2"}
) )
# whenever send_transaction is called, record the pdu data # whenever send_transaction is called, record the pdu data

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Callable, FrozenSet, List, Optional, Set from typing import Callable, FrozenSet, List, Optional, Set
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from signedjson import key, sign from signedjson import key, sign
from signedjson.types import BaseKey, SigningKey from signedjson.types import BaseKey, SigningKey
@ -29,7 +29,6 @@ from synapse.server import HomeServer
from synapse.types import JsonDict, ReadReceipt from synapse.types import JsonDict, ReadReceipt
from synapse.util import Clock from synapse.util import Clock
from tests.test_utils import make_awaitable
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -43,12 +42,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.federation_transport_client = Mock(spec=["send_transaction"]) self.federation_transport_client = Mock(spec=["send_transaction"])
self.federation_transport_client.send_transaction = AsyncMock()
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
federation_transport_client=self.federation_transport_client, federation_transport_client=self.federation_transport_client,
) )
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment] hs.get_storage_controllers().state.get_current_hosts_in_room = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable({"test", "host2"}) return_value={"test", "host2"}
) )
hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[assignment] hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[assignment]
@ -64,7 +64,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts(self) -> None: def test_send_receipts(self) -> None:
mock_send_transaction = self.federation_transport_client.send_transaction mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({}) mock_send_transaction.return_value = {}
sender = self.hs.get_federation_sender() sender = self.hs.get_federation_sender()
receipt = ReadReceipt( receipt = ReadReceipt(
@ -104,7 +104,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts_thread(self) -> None: def test_send_receipts_thread(self) -> None:
mock_send_transaction = self.federation_transport_client.send_transaction mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({}) mock_send_transaction.return_value = {}
# Create receipts for: # Create receipts for:
# #
@ -180,7 +180,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
"""Send two receipts in quick succession; the second should be flushed, but """Send two receipts in quick succession; the second should be flushed, but
only after 20ms""" only after 20ms"""
mock_send_transaction = self.federation_transport_client.send_transaction mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({}) mock_send_transaction.return_value = {}
sender = self.hs.get_federation_sender() sender = self.hs.get_federation_sender()
receipt = ReadReceipt( receipt = ReadReceipt(
@ -276,6 +276,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.federation_transport_client = Mock( self.federation_transport_client = Mock(
spec=["send_transaction", "query_user_devices"] spec=["send_transaction", "query_user_devices"]
) )
self.federation_transport_client.send_transaction = AsyncMock()
self.federation_transport_client.query_user_devices = AsyncMock()
return self.setup_test_homeserver( return self.setup_test_homeserver(
federation_transport_client=self.federation_transport_client, federation_transport_client=self.federation_transport_client,
) )
@ -317,13 +319,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.record_transaction self.record_transaction
) )
def record_transaction( async def record_transaction(
self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None
) -> "defer.Deferred[JsonDict]": ) -> JsonDict:
assert json_cb is not None assert json_cb is not None
data = json_cb() data = json_cb()
self.edus.extend(data["edus"]) self.edus.extend(data["edus"])
return defer.succeed({}) return {}
def test_send_device_updates(self) -> None: def test_send_device_updates(self) -> None:
"""Basic case: each device update should result in an EDU""" """Basic case: each device update should result in an EDU"""
@ -354,15 +356,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# Send the server a device list EDU for the other user, this will cause # Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists. # it to try and resync the device lists.
self.federation_transport_client.query_user_devices.return_value = ( self.federation_transport_client.query_user_devices.return_value = {
make_awaitable( "stream_id": "1",
{ "user_id": "@user2:host2",
"stream_id": "1", "devices": [{"device_id": "D1"}],
"user_id": "@user2:host2", }
"devices": [{"device_id": "D1"}],
}
)
)
self.get_success( self.get_success(
self.device_handler.device_list_updater.incoming_device_list_update( self.device_handler.device_list_updater.incoming_device_list_update(
@ -533,7 +531,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
recovery recovery
""" """
mock_send_txn = self.federation_transport_client.send_transaction mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) mock_send_txn.side_effect = AssertionError("fail")
# create devices # create devices
u1 = self.register_user("user", "pass") u1 = self.register_user("user", "pass")
@ -578,7 +576,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
This case tests the behaviour when the server has never been reachable. This case tests the behaviour when the server has never been reachable.
""" """
mock_send_txn = self.federation_transport_client.send_transaction mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) mock_send_txn.side_effect = AssertionError("fail")
# create devices # create devices
u1 = self.register_user("user", "pass") u1 = self.register_user("user", "pass")
@ -636,7 +634,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# now the server goes offline # now the server goes offline
mock_send_txn = self.federation_transport_client.send_transaction mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) mock_send_txn.side_effect = AssertionError("fail")
self.login("user", "pass", device_id="D2") self.login("user", "pass", device_id="D2")
self.login("user", "pass", device_id="D3") self.login("user", "pass", device_id="D3")

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from typing import Dict, Iterable, List, Optional from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from parameterized import parameterized from parameterized import parameterized
@ -36,7 +36,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from tests import unittest from tests import unittest
from tests.test_utils import event_injection, make_awaitable, simple_async_mock from tests.test_utils import event_injection, simple_async_mock
from tests.unittest import override_config from tests.unittest import override_config
from tests.utils import MockClock from tests.utils import MockClock
@ -46,15 +46,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.mock_store = Mock() self.mock_store = Mock()
self.mock_as_api = Mock() self.mock_as_api = AsyncMock()
self.mock_scheduler = Mock() self.mock_scheduler = Mock()
hs = Mock() hs = Mock()
hs.get_datastores.return_value = Mock(main=self.mock_store) hs.get_datastores.return_value = Mock(main=self.mock_store)
self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None) self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None)
self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) self.mock_store.set_appservice_last_pos = AsyncMock(return_value=None)
self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable( self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None)
None
)
hs.get_application_service_api.return_value = self.mock_as_api hs.get_application_service_api.return_value = self.mock_as_api
hs.get_application_service_scheduler.return_value = self.mock_scheduler hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock() hs.get_clock.return_value = MockClock()
@ -69,21 +67,25 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice(is_interested_in_event=False), self._mkservice(is_interested_in_event=False),
] ]
self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_as_api.query_user.return_value = True
self.mock_store.get_app_services.return_value = services self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = make_awaitable([]) self.mock_store.get_user_by_id = AsyncMock(return_value=[])
event = Mock( event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
) )
self.mock_store.get_all_new_event_ids_stream.side_effect = [ self.mock_store.get_all_new_event_ids_stream = AsyncMock(
make_awaitable((0, {})), side_effect=[
make_awaitable((1, {event.event_id: 0})), (0, {}),
] (1, {event.event_id: 0}),
self.mock_store.get_events_as_list.side_effect = [ ]
make_awaitable([]), )
make_awaitable([event]), self.mock_store.get_events_as_list = AsyncMock(
] side_effect=[
[],
[event],
]
)
self.handler.notify_interested_services(RoomStreamToken(None, 1)) self.handler.notify_interested_services(RoomStreamToken(None, 1))
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
@ -95,14 +97,16 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested_in_event=True)] services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = make_awaitable(None) self.mock_store.get_user_by_id = AsyncMock(return_value=None)
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_as_api.query_user.return_value = True
self.mock_store.get_all_new_event_ids_stream.side_effect = [ self.mock_store.get_all_new_event_ids_stream = AsyncMock(
make_awaitable((0, {event.event_id: 0})), side_effect=[
] (0, {event.event_id: 0}),
self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])] ]
)
self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
self.handler.notify_interested_services(RoomStreamToken(None, 0)) self.handler.notify_interested_services(RoomStreamToken(None, 0))
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@ -112,13 +116,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested_in_event=True)] services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id}) self.mock_store.get_user_by_id = AsyncMock(return_value={"name": user_id})
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_as_api.query_user.return_value = True
self.mock_store.get_all_new_event_ids_stream.side_effect = [ self.mock_store.get_all_new_event_ids_stream = AsyncMock(
make_awaitable((0, [event], {event.event_id: 0})), side_effect=[
] (0, [event], {event.event_id: 0}),
]
)
self.handler.notify_interested_services(RoomStreamToken(None, 0)) self.handler.notify_interested_services(RoomStreamToken(None, 0))
@ -141,10 +147,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice_alias(is_room_alias_in_namespace=False), self._mkservice_alias(is_room_alias_in_namespace=False),
] ]
self.mock_as_api.query_alias.return_value = make_awaitable(True) self.mock_as_api.query_alias = AsyncMock(return_value=True)
self.mock_store.get_app_services.return_value = services self.mock_store.get_app_services.return_value = services
self.mock_store.get_association_from_room_alias.return_value = make_awaitable( self.mock_store.get_association_from_room_alias = AsyncMock(
Mock(room_id=room_id, servers=servers) return_value=Mock(room_id=room_id, servers=servers)
) )
result = self.successResultOf( result = self.successResultOf(
@ -177,7 +183,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_get_3pe_protocols_protocol_no_response(self) -> None: def test_get_3pe_protocols_protocol_no_response(self) -> None:
service = self._mkservice(False, ["my-protocol"]) service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service] self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None) self.mock_as_api.get_3pe_protocol.return_value = None
response = self.successResultOf( response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols()) defer.ensureDeferred(self.handler.get_3pe_protocols())
) )
@ -189,9 +195,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_get_3pe_protocols_select_one_protocol(self) -> None: def test_get_3pe_protocols_select_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"]) service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service] self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( self.mock_as_api.get_3pe_protocol.return_value = {
{"x-protocol-data": 42, "instances": []} "x-protocol-data": 42,
) "instances": [],
}
response = self.successResultOf( response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
) )
@ -205,9 +212,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_get_3pe_protocols_one_protocol(self) -> None: def test_get_3pe_protocols_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"]) service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service] self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( self.mock_as_api.get_3pe_protocol.return_value = {
{"x-protocol-data": 42, "instances": []} "x-protocol-data": 42,
) "instances": [],
}
response = self.successResultOf( response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols()) defer.ensureDeferred(self.handler.get_3pe_protocols())
) )
@ -222,9 +230,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
service_one = self._mkservice(False, ["my-protocol"]) service_one = self._mkservice(False, ["my-protocol"])
service_two = self._mkservice(False, ["other-protocol"]) service_two = self._mkservice(False, ["other-protocol"])
self.mock_store.get_app_services.return_value = [service_one, service_two] self.mock_store.get_app_services.return_value = [service_one, service_two]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( self.mock_as_api.get_3pe_protocol.return_value = {
{"x-protocol-data": 42, "instances": []} "x-protocol-data": 42,
) "instances": [],
}
response = self.successResultOf( response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols()) defer.ensureDeferred(self.handler.get_3pe_protocols())
) )
@ -287,13 +296,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
interested_service = self._mkservice(is_interested_in_event=True) interested_service = self._mkservice(is_interested_in_event=True)
services = [interested_service] services = [interested_service]
self.mock_store.get_app_services.return_value = services self.mock_store.get_app_services.return_value = services
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=579)
579
)
event = Mock(event_id="event_1") event = Mock(event_id="event_1")
self.event_source.sources.receipt.get_new_events_as.return_value = ( self.event_source.sources.receipt.get_new_events_as = AsyncMock(
make_awaitable(([event], None)) return_value=([event], None)
) )
self.handler.notify_interested_services_ephemeral( self.handler.notify_interested_services_ephemeral(
@ -317,13 +324,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [interested_service] services = [interested_service]
self.mock_store.get_app_services.return_value = services self.mock_store.get_app_services.return_value = services
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=580)
580
)
event = Mock(event_id="event_1") event = Mock(event_id="event_1")
self.event_source.sources.receipt.get_new_events_as.return_value = ( self.event_source.sources.receipt.get_new_events_as = AsyncMock(
make_awaitable(([event], None)) return_value=([event], None)
) )
self.handler.notify_interested_services_ephemeral( self.handler.notify_interested_services_ephemeral(
@ -350,9 +355,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
A mock representing the ApplicationService. A mock representing the ApplicationService.
""" """
service = Mock() service = Mock()
service.is_interested_in_event.return_value = make_awaitable( service.is_interested_in_event = AsyncMock(return_value=is_interested_in_event)
is_interested_in_event
)
service.token = "mock_service_token" service.token = "mock_service_token"
service.url = "mock_service_url" service.url = "mock_service_url"
service.protocols = protocols service.protocols = protocols

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional from typing import Optional
from unittest.mock import Mock from unittest.mock import AsyncMock
import pymacaroons import pymacaroons
@ -25,7 +25,6 @@ from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
class AuthTestCase(unittest.HomeserverTestCase): class AuthTestCase(unittest.HomeserverTestCase):
@ -166,8 +165,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_mau_limits_exceeded_large(self) -> None: def test_mau_limits_exceeded_large(self) -> None:
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock( self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
return_value=make_awaitable(self.large_number_of_users) return_value=self.large_number_of_users
) )
self.get_failure( self.get_failure(
@ -177,8 +176,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError, ResourceLimitError,
) )
self.hs.get_datastores().main.get_monthly_active_count = Mock( self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
return_value=make_awaitable(self.large_number_of_users) return_value=self.large_number_of_users
) )
token = self.get_success( token = self.get_success(
self.auth_handler.create_login_token_for_user_id(self.user1) self.auth_handler.create_login_token_for_user_id(self.user1)
@ -191,8 +190,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
# Set the server to be at the edge of too many users. # Set the server to be at the edge of too many users.
self.hs.get_datastores().main.get_monthly_active_count = Mock( self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
return_value=make_awaitable(self.auth_blocking._max_mau_value) return_value=self.auth_blocking._max_mau_value
) )
# If not in monthly active cohort # If not in monthly active cohort
@ -208,8 +207,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertIsNone(self.token_login(token)) self.assertIsNone(self.token_login(token))
# If in monthly active cohort # If in monthly active cohort
self.hs.get_datastores().main.user_last_seen_monthly_active = Mock( self.hs.get_datastores().main.user_last_seen_monthly_active = AsyncMock(
return_value=make_awaitable(self.clock.time_msec()) return_value=self.clock.time_msec()
) )
self.get_success( self.get_success(
self.auth_handler.create_access_token_for_user_id( self.auth_handler.create_access_token_for_user_id(
@ -224,8 +223,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_mau_limits_not_exceeded(self) -> None: def test_mau_limits_not_exceeded(self) -> None:
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock( self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
return_value=make_awaitable(self.small_number_of_users) return_value=self.small_number_of_users
) )
# Ensure does not raise exception # Ensure does not raise exception
self.get_success( self.get_success(
@ -234,8 +233,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
) )
) )
self.hs.get_datastores().main.get_monthly_active_count = Mock( self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
return_value=make_awaitable(self.small_number_of_users) return_value=self.small_number_of_users
) )
token = self.get_success( token = self.get_success(
self.auth_handler.create_login_token_for_user_id(self.user1) self.auth_handler.create_login_token_for_user_id(self.user1)

View file

@ -32,7 +32,6 @@ from synapse.types import JsonDict, create_requester
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
user1 = "@boris:aaa" user1 = "@boris:aaa"
@ -41,7 +40,7 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase): class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.appservice_api = mock.Mock() self.appservice_api = mock.AsyncMock()
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
"server", "server",
application_service_api=self.appservice_api, application_service_api=self.appservice_api,
@ -375,13 +374,11 @@ class DeviceTestCase(unittest.HomeserverTestCase):
) )
# Setup a response. # Setup a response.
self.appservice_api.query_keys.return_value = make_awaitable( self.appservice_api.query_keys.return_value = {
{ "device_keys": {
"device_keys": { local_user: {device_2: device_key_2b, device_3: device_key_3}
local_user: {device_2: device_key_2b, device_3: device_key_3}
}
} }
) }
# Request all devices. # Request all devices.
res = self.get_success( res = self.get_success(

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Awaitable, Callable, Dict from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -27,14 +27,13 @@ from synapse.types import JsonDict, RoomAlias, create_requester
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase): class DirectoryTestCase(unittest.HomeserverTestCase):
"""Tests the directory service.""" """Tests the directory service."""
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock() self.mock_federation = AsyncMock()
self.mock_registry = Mock() self.mock_registry = Mock()
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
@ -73,9 +72,10 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result) self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
def test_get_remote_association(self) -> None: def test_get_remote_association(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable( self.mock_federation.make_query.return_value = {
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]} "room_id": "!8765qwer:test",
) "servers": ["test", "remote"],
}
result = self.get_success(self.handler.get_association(self.remote_room)) result = self.get_success(self.handler.get_association(self.remote_room))

View file

@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Iterable from typing import Dict, Iterable
from unittest import mock from unittest import mock
from parameterized import parameterized from parameterized import parameterized
@ -31,13 +31,12 @@ from synapse.types import JsonDict, UserID
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.appservice_api = mock.Mock() self.appservice_api = mock.AsyncMock()
return self.setup_test_homeserver( return self.setup_test_homeserver(
federation_client=mock.Mock(), application_service_api=self.appservice_api federation_client=mock.Mock(), application_service_api=self.appservice_api
) )
@ -801,29 +800,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment] self.hs.get_federation_client().query_client_keys = mock.AsyncMock( # type: ignore[assignment]
return_value=make_awaitable( return_value={
{ "device_keys": {remote_user_id: {}},
"device_keys": {remote_user_id: {}}, "master_keys": {
"master_keys": { remote_user_id: {
remote_user_id: { "user_id": remote_user_id,
"user_id": remote_user_id, "usage": ["master"],
"usage": ["master"], "keys": {"ed25519:" + remote_master_key: remote_master_key},
"keys": {"ed25519:" + remote_master_key: remote_master_key}, },
},
"self_signing_keys": {
remote_user_id: {
"user_id": remote_user_id,
"usage": ["self_signing"],
"keys": {
"ed25519:"
+ remote_self_signing_key: remote_self_signing_key
}, },
}, }
"self_signing_keys": { },
remote_user_id: { }
"user_id": remote_user_id,
"usage": ["self_signing"],
"keys": {
"ed25519:"
+ remote_self_signing_key: remote_self_signing_key
},
}
},
}
)
) )
e2e_handler = self.hs.get_e2e_keys_handler() e2e_handler = self.hs.get_e2e_keys_handler()
@ -874,34 +871,29 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Pretend we're sharing a room with the user we're querying. If not, # Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early. # `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock( self.store.get_rooms_for_user = mock.AsyncMock(return_value={"some_room_id"})
return_value=make_awaitable({"some_room_id"})
)
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment] self.hs.get_federation_client().query_user_devices = mock.AsyncMock( # type: ignore[assignment]
return_value=make_awaitable( return_value={
{ "user_id": remote_user_id,
"stream_id": 1,
"devices": [],
"master_key": {
"user_id": remote_user_id, "user_id": remote_user_id,
"stream_id": 1, "usage": ["master"],
"devices": [], "keys": {"ed25519:" + remote_master_key: remote_master_key},
"master_key": { },
"user_id": remote_user_id, "self_signing_key": {
"usage": ["master"], "user_id": remote_user_id,
"keys": {"ed25519:" + remote_master_key: remote_master_key}, "usage": ["self_signing"],
"keys": {
"ed25519:" + remote_self_signing_key: remote_self_signing_key
}, },
"self_signing_key": { },
"user_id": remote_user_id, }
"usage": ["self_signing"],
"keys": {
"ed25519:"
+ remote_self_signing_key: remote_self_signing_key
},
},
}
)
) )
e2e_handler = self.hs.get_e2e_keys_handler() e2e_handler = self.hs.get_e2e_keys_handler()
@ -987,20 +979,20 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
mock_get_rooms = mock.patch.object( mock_get_rooms = mock.patch.object(
self.store, self.store,
"get_rooms_for_user", "get_rooms_for_user",
new_callable=mock.MagicMock, new_callable=mock.AsyncMock,
return_value=make_awaitable(["some_room_id"]), return_value=["some_room_id"],
) )
mock_get_users = mock.patch.object( mock_get_users = mock.patch.object(
self.store, self.store,
"get_users_server_still_shares_room_with", "get_users_server_still_shares_room_with",
new_callable=mock.MagicMock, new_callable=mock.AsyncMock,
return_value=make_awaitable({remote_user_id}), return_value={remote_user_id},
) )
mock_request = mock.patch.object( mock_request = mock.patch.object(
self.hs.get_federation_client(), self.hs.get_federation_client(),
"query_user_devices", "query_user_devices",
new_callable=mock.MagicMock, new_callable=mock.AsyncMock,
return_value=make_awaitable(response_body), return_value=response_body,
) )
with mock_get_rooms, mock_get_users, mock_request as mocked_federation_request: with mock_get_rooms, mock_get_users, mock_request as mocked_federation_request:
@ -1060,8 +1052,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
) )
# Setup a response, but only for device 2. # Setup a response, but only for device 2.
self.appservice_api.claim_client_keys.return_value = make_awaitable( self.appservice_api.claim_client_keys.return_value = (
({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1", 1)]) {local_user: {device_id_2: otk}},
[(local_user, device_id_1, "alg1", 1)],
) )
# we shouldn't have any unused fallback keys yet # we shouldn't have any unused fallback keys yet
@ -1127,9 +1120,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
) )
# Setup a response. # Setup a response.
self.appservice_api.claim_client_keys.return_value = make_awaitable( response: Dict[str, Dict[str, Dict[str, JsonDict]]] = {
({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, []) local_user: {device_id_1: {**as_otk, **as_fallback_key}}
) }
self.appservice_api.claim_client_keys.return_value = (response, [])
# Claim OTKs, which will ask the appservice and do nothing else. # Claim OTKs, which will ask the appservice and do nothing else.
claim_res = self.get_success( claim_res = self.get_success(
@ -1171,8 +1165,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.assertEqual(fallback_res, ["alg1"]) self.assertEqual(fallback_res, ["alg1"])
# The appservice will return only the OTK. # The appservice will return only the OTK.
self.appservice_api.claim_client_keys.return_value = make_awaitable( self.appservice_api.claim_client_keys.return_value = (
({local_user: {device_id_1: as_otk}}, []) {local_user: {device_id_1: as_otk}},
[],
) )
# Claim OTKs, which should return the OTK from the appservice and the # Claim OTKs, which should return the OTK from the appservice and the
@ -1234,8 +1229,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.assertEqual(fallback_res, ["alg1"]) self.assertEqual(fallback_res, ["alg1"])
# Finally, return only the fallback key from the appservice. # Finally, return only the fallback key from the appservice.
self.appservice_api.claim_client_keys.return_value = make_awaitable( self.appservice_api.claim_client_keys.return_value = (
({local_user: {device_id_1: as_fallback_key}}, []) {local_user: {device_id_1: as_fallback_key}},
[],
) )
# Claim OTKs, which will return only the fallback key from the database. # Claim OTKs, which will return only the fallback key from the database.
@ -1350,13 +1346,11 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
) )
# Setup a response. # Setup a response.
self.appservice_api.query_keys.return_value = make_awaitable( self.appservice_api.query_keys.return_value = {
{ "device_keys": {
"device_keys": { local_user: {device_2: device_key_2b, device_3: device_key_3}
local_user: {device_2: device_key_2b, device_3: device_key_3}
}
} }
) }
# Request all devices. # Request all devices.
res = self.get_success(self.handler.query_local_devices({local_user: None})) res = self.get_success(self.handler.query_local_devices({local_user: None}))

View file

@ -14,7 +14,7 @@
import logging import logging
from typing import Collection, Optional, cast from typing import Collection, Optional, cast
from unittest import TestCase from unittest import TestCase
from unittest.mock import Mock, patch from unittest.mock import AsyncMock, Mock, patch
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -40,7 +40,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from tests import unittest from tests import unittest
from tests.test_utils import event_injection, make_awaitable from tests.test_utils import event_injection
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -370,7 +370,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# We mock out the FederationClient.backfill method, to pretend that a remote # We mock out the FederationClient.backfill method, to pretend that a remote
# server has returned our fake event. # server has returned our fake event.
federation_client_backfill_mock = Mock(return_value=make_awaitable([event])) federation_client_backfill_mock = AsyncMock(return_value=[event])
self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment] self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment]
# We also mock the persist method with a side effect of itself. This allows us # We also mock the persist method with a side effect of itself. This allows us
@ -631,33 +631,29 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
}, },
RoomVersions.V10, RoomVersions.V10,
) )
mock_make_membership_event = Mock( mock_make_membership_event = AsyncMock(
return_value=make_awaitable( return_value=(
( "example.com",
"example.com", membership_event,
membership_event, RoomVersions.V10,
RoomVersions.V10,
)
) )
) )
mock_send_join = Mock( mock_send_join = AsyncMock(
return_value=make_awaitable( return_value=SendJoinResult(
SendJoinResult( membership_event,
membership_event, "example.com",
"example.com", state=[
state=[ EVENT_CREATE,
EVENT_CREATE, EVENT_CREATOR_MEMBERSHIP,
EVENT_CREATOR_MEMBERSHIP, EVENT_INVITATION_MEMBERSHIP,
EVENT_INVITATION_MEMBERSHIP, ],
], auth_chain=[
auth_chain=[ EVENT_CREATE,
EVENT_CREATE, EVENT_CREATOR_MEMBERSHIP,
EVENT_CREATOR_MEMBERSHIP, EVENT_INVITATION_MEMBERSHIP,
EVENT_INVITATION_MEMBERSHIP, ],
], partial_state=True,
partial_state=True, servers_in_room={"example.com"},
servers_in_room={"example.com"},
)
) )
) )

View file

@ -35,7 +35,7 @@ from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import event_injection, make_awaitable from tests.test_utils import event_injection
class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
@ -50,6 +50,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
self.mock_federation_transport_client = mock.Mock( self.mock_federation_transport_client = mock.Mock(
spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"] spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
) )
self.mock_federation_transport_client.get_room_state_ids = mock.AsyncMock()
self.mock_federation_transport_client.get_room_state = mock.AsyncMock()
self.mock_federation_transport_client.get_event = mock.AsyncMock()
self.mock_federation_transport_client.backfill = mock.AsyncMock()
return super().setup_test_homeserver( return super().setup_test_homeserver(
federation_transport_client=self.mock_federation_transport_client federation_transport_client=self.mock_federation_transport_client
) )
@ -198,20 +202,14 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
) )
# we expect an outbound request to /state_ids, so stub that out # we expect an outbound request to /state_ids, so stub that out
self.mock_federation_transport_client.get_room_state_ids.return_value = ( self.mock_federation_transport_client.get_room_state_ids.return_value = {
make_awaitable( "pdu_ids": [e.event_id for e in state_at_prev_event],
{ "auth_chain_ids": [],
"pdu_ids": [e.event_id for e in state_at_prev_event], }
"auth_chain_ids": [],
}
)
)
# we also expect an outbound request to /state # we also expect an outbound request to /state
self.mock_federation_transport_client.get_room_state.return_value = ( self.mock_federation_transport_client.get_room_state.return_value = (
make_awaitable( StateRequestResponse(auth_events=[], state=state_at_prev_event)
StateRequestResponse(auth_events=[], state=state_at_prev_event)
)
) )
# we have to bump the clock a bit, to keep the retry logic in # we have to bump the clock a bit, to keep the retry logic in
@ -273,26 +271,23 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
room_version = self.get_success(main_store.get_room_version(room_id)) room_version = self.get_success(main_store.get_room_version(room_id))
# We expect an outbound request to /state_ids, so stub that out # We expect an outbound request to /state_ids, so stub that out
self.mock_federation_transport_client.get_room_state_ids.return_value = make_awaitable( self.mock_federation_transport_client.get_room_state_ids.return_value = {
{ # Mimic the other server not knowing about the state at all.
# Mimic the other server not knowing about the state at all. # We want to cause Synapse to throw an error (`Unable to get
# We want to cause Synapse to throw an error (`Unable to get # missing prev_event $fake_prev_event`) and fail to backfill
# missing prev_event $fake_prev_event`) and fail to backfill # the pulled event.
# the pulled event. "pdu_ids": [],
"pdu_ids": [], "auth_chain_ids": [],
"auth_chain_ids": [], }
}
)
# We also expect an outbound request to /state # We also expect an outbound request to /state
self.mock_federation_transport_client.get_room_state.return_value = make_awaitable( self.mock_federation_transport_client.get_room_state.return_value = StateRequestResponse(
StateRequestResponse( # Mimic the other server not knowing about the state at all.
# Mimic the other server not knowing about the state at all. # We want to cause Synapse to throw an error (`Unable to get
# We want to cause Synapse to throw an error (`Unable to get # missing prev_event $fake_prev_event`) and fail to backfill
# missing prev_event $fake_prev_event`) and fail to backfill # the pulled event.
# the pulled event. auth_events=[],
auth_events=[], state=[],
state=[],
)
) )
pulled_event = make_event_from_dict( pulled_event = make_event_from_dict(
@ -545,25 +540,23 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
) )
# We expect an outbound request to /backfill, so stub that out # We expect an outbound request to /backfill, so stub that out
self.mock_federation_transport_client.backfill.return_value = make_awaitable( self.mock_federation_transport_client.backfill.return_value = {
{ "origin": self.OTHER_SERVER_NAME,
"origin": self.OTHER_SERVER_NAME, "origin_server_ts": 123,
"origin_server_ts": 123, "pdus": [
"pdus": [ # This is one of the important aspects of this test: we include
# This is one of the important aspects of this test: we include # `pulled_event_without_signatures` so it fails the signature check
# `pulled_event_without_signatures` so it fails the signature check # when we filter down the backfill response down to events which
# when we filter down the backfill response down to events which # have valid signatures in
# have valid signatures in # `_check_sigs_and_hash_for_pulled_events_and_fetch`
# `_check_sigs_and_hash_for_pulled_events_and_fetch` pulled_event_without_signatures.get_pdu_json(),
pulled_event_without_signatures.get_pdu_json(), # Then later when we process this valid signature event, when we
# Then later when we process this valid signature event, when we # fetch the missing `prev_event`s, we want to make sure that we
# fetch the missing `prev_event`s, we want to make sure that we # backoff and don't try and fetch `pulled_event_without_signatures`
# backoff and don't try and fetch `pulled_event_without_signatures` # again since we know it just had an invalid signature.
# again since we know it just had an invalid signature. pulled_event.get_pdu_json(),
pulled_event.get_pdu_json(), ],
], }
}
)
# Keep track of the count and make sure we don't make any of these requests # Keep track of the count and make sure we don't make any of these requests
event_endpoint_requested_count = 0 event_endpoint_requested_count = 0
@ -731,15 +724,13 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
) )
# We expect an outbound request to /backfill, so stub that out # We expect an outbound request to /backfill, so stub that out
self.mock_federation_transport_client.backfill.return_value = make_awaitable( self.mock_federation_transport_client.backfill.return_value = {
{ "origin": self.OTHER_SERVER_NAME,
"origin": self.OTHER_SERVER_NAME, "origin_server_ts": 123,
"origin_server_ts": 123, "pdus": [
"pdus": [ pulled_event.get_pdu_json(),
pulled_event.get_pdu_json(), ],
], }
}
)
# The function under test: try to backfill and process the pulled event # The function under test: try to backfill and process the pulled event
with LoggingContext("test"): with LoggingContext("test"):

View file

@ -16,7 +16,7 @@
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Dict, List, Optional, Type, Union from typing import Any, Dict, List, Optional, Type, Union
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -32,7 +32,6 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import FakeChannel from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
# Login flows we expect to appear in the list after the normal ones. # Login flows we expect to appear in the list after the normal ones.
@ -187,7 +186,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
# check_password must return an awaitable # check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(True) mock_password_provider.check_password = AsyncMock(return_value=True)
channel = self._send_password_login("u", "p") channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"]) self.assertEqual("@u:test", channel.json_body["user_id"])
@ -209,13 +208,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"""UI Auth should delegate correctly to the password provider""" """UI Auth should delegate correctly to the password provider"""
# log in twice, to get two devices # log in twice, to get two devices
mock_password_provider.check_password.return_value = make_awaitable(True) mock_password_provider.check_password = AsyncMock(return_value=True)
tok1 = self.login("u", "p") tok1 = self.login("u", "p")
self.login("u", "p", device_id="dev2") self.login("u", "p", device_id="dev2")
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
# have the auth provider deny the request to start with # have the auth provider deny the request to start with
mock_password_provider.check_password.return_value = make_awaitable(False) mock_password_provider.check_password = AsyncMock(return_value=False)
# make the initial request which returns a 401 # make the initial request which returns a 401
session = self._start_delete_device_session(tok1, "dev2") session = self._start_delete_device_session(tok1, "dev2")
@ -229,7 +228,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
# Finally, check the request goes through when we allow it # Finally, check the request goes through when we allow it
mock_password_provider.check_password.return_value = make_awaitable(True) mock_password_provider.check_password = AsyncMock(return_value=True)
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p") mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@ -243,7 +242,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
# check_password must return an awaitable # check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(False) mock_password_provider.check_password = AsyncMock(return_value=False)
channel = self._send_password_login("u", "p") channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
@ -260,7 +259,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
# have the auth provider deny the request # have the auth provider deny the request
mock_password_provider.check_password.return_value = make_awaitable(False) mock_password_provider.check_password = AsyncMock(return_value=False)
# log in twice, to get two devices # log in twice, to get two devices
tok1 = self.login("localuser", "localpass") tok1 = self.login("localuser", "localpass")
@ -303,7 +302,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
# check_password must return an awaitable # check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(False) mock_password_provider.check_password = AsyncMock(return_value=False)
channel = self._send_password_login("localuser", "localpass") channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 403) self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@ -325,7 +324,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
# allow login via the auth provider # allow login via the auth provider
mock_password_provider.check_password.return_value = make_awaitable(True) mock_password_provider.check_password = AsyncMock(return_value=True)
# log in twice, to get two devices # log in twice, to get two devices
tok1 = self.login("localuser", "p") tok1 = self.login("localuser", "p")
@ -342,7 +341,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_not_called() mock_password_provider.check_password.assert_not_called()
# now try deleting with the local password # now try deleting with the local password
mock_password_provider.check_password.return_value = make_awaitable(False) mock_password_provider.check_password = AsyncMock(return_value=False)
channel = self._authed_delete_device( channel = self._authed_delete_device(
tok1, "dev2", session, "localuser", "localpass" tok1, "dev2", session, "localuser", "localpass"
) )
@ -396,9 +395,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_auth.return_value = make_awaitable( mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None))
("@user:test", None)
)
channel = self._send_login("test.login_type", "u", test_field="y") channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:test", channel.json_body["user_id"]) self.assertEqual("@user:test", channel.json_body["user_id"])
@ -447,9 +444,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
# right params, but authing as the wrong user # right params, but authing as the wrong user
mock_password_provider.check_auth.return_value = make_awaitable( mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None))
("@user:test", None)
)
body["auth"]["test_field"] = "foo" body["auth"]["test_field"] = "foo"
channel = self._delete_device(tok1, "dev2", body) channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 403) self.assertEqual(channel.code, 403)
@ -460,8 +455,8 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
# and finally, succeed # and finally, succeed
mock_password_provider.check_auth.return_value = make_awaitable( mock_password_provider.check_auth = AsyncMock(
("@localuser:test", None) return_value=("@localuser:test", None)
) )
channel = self._delete_device(tok1, "dev2", body) channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
@ -478,10 +473,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.custom_auth_provider_callback_test_body() self.custom_auth_provider_callback_test_body()
def custom_auth_provider_callback_test_body(self) -> None: def custom_auth_provider_callback_test_body(self) -> None:
callback = Mock(return_value=make_awaitable(None)) callback = AsyncMock(return_value=None)
mock_password_provider.check_auth.return_value = make_awaitable( mock_password_provider.check_auth = AsyncMock(
("@user:test", callback) return_value=("@user:test", callback)
) )
channel = self._send_login("test.login_type", "u", test_field="y") channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@ -616,8 +611,8 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
login is disabled""" login is disabled"""
# register the user and log in twice via the test login type to get two devices, # register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
mock_password_provider.check_auth.return_value = make_awaitable( mock_password_provider.check_auth = AsyncMock(
("@localuser:test", None) return_value=("@localuser:test", None)
) )
channel = self._send_login("test.login_type", "localuser", test_field="") channel = self._send_login("test.login_type", "localuser", test_field="")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@ -835,11 +830,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
username: The username to use for the test. username: The username to use for the test.
registration: Whether to test with registration URLs. registration: Whether to test with registration URLs.
""" """
self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment] self.hs.get_identity_handler().send_threepid_validation = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(0), return_value=0
) )
m = Mock(return_value=make_awaitable(False)) m = AsyncMock(return_value=False)
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m] self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
self.register_user(username, "password") self.register_user(username, "password")
@ -869,7 +864,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
m.assert_called_once_with("email", "foo@test.com", registration) m.assert_called_once_with("email", "foo@test.com", registration)
m = Mock(return_value=make_awaitable(True)) m = AsyncMock(return_value=True)
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m] self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
channel = self.make_request( channel = self.make_request(

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Awaitable, Callable, Dict from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from parameterized import parameterized from parameterized import parameterized
@ -26,7 +26,6 @@ from synapse.types import JsonDict, UserID
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
class ProfileTestCase(unittest.HomeserverTestCase): class ProfileTestCase(unittest.HomeserverTestCase):
@ -35,7 +34,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
servlets = [admin.register_servlets] servlets = [admin.register_servlets]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock() self.mock_federation = AsyncMock()
self.mock_registry = Mock() self.mock_registry = Mock()
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
@ -135,9 +134,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
) )
def test_get_other_name(self) -> None: def test_get_other_name(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable( self.mock_federation.make_query.return_value = {"displayname": "Alice"}
{"displayname": "Alice"}
)
displayname = self.get_success(self.handler.get_displayname(self.alice)) displayname = self.get_success(self.handler.get_displayname(self.alice))

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from typing import Any, Collection, List, Optional, Tuple from typing import Any, Collection, List, Optional, Tuple
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -38,7 +38,6 @@ from synapse.types import (
) )
from synapse.util import Clock from synapse.util import Clock
from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
from tests.utils import mock_getRawHeaders from tests.utils import mock_getRawHeaders
@ -203,24 +202,22 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True}) @override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self) -> None: def test_get_or_create_user_mau_not_blocked(self) -> None:
self.store.count_monthly_users = Mock( # type: ignore[assignment] self.store.count_monthly_users = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) return_value=self.hs.config.server.max_mau_value - 1
) )
# Ensure does not throw exception # Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "c", "User")) self.get_success(self.get_or_create_user(self.requester, "c", "User"))
@override_config({"limit_usage_by_mau": True}) @override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_blocked(self) -> None: def test_get_or_create_user_mau_blocked(self) -> None:
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users)
return_value=make_awaitable(self.lots_of_users)
)
self.get_failure( self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"), self.get_or_create_user(self.requester, "b", "display_name"),
ResourceLimitError, ResourceLimitError,
) )
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = AsyncMock(
return_value=make_awaitable(self.hs.config.server.max_mau_value) return_value=self.hs.config.server.max_mau_value
) )
self.get_failure( self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"), self.get_or_create_user(self.requester, "b", "display_name"),
@ -229,15 +226,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True}) @override_config({"limit_usage_by_mau": True})
def test_register_mau_blocked(self) -> None: def test_register_mau_blocked(self) -> None:
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users)
return_value=make_awaitable(self.lots_of_users)
)
self.get_failure( self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError self.handler.register_user(localpart="local_part"), ResourceLimitError
) )
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = AsyncMock(
return_value=make_awaitable(self.hs.config.server.max_mau_value) return_value=self.hs.config.server.max_mau_value
) )
self.get_failure( self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError self.handler.register_user(localpart="local_part"), ResourceLimitError
@ -292,7 +287,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"auto_join_rooms": ["#room:test"]}) @override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None: def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None:
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.store.is_real_user = Mock(return_value=make_awaitable(False)) self.store.is_real_user = AsyncMock(return_value=False)
user_id = self.get_success(self.handler.register_user(localpart="support")) user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0) self.assertEqual(len(rooms), 0)
@ -304,8 +299,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None: def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment] self.store.count_real_users = AsyncMock(return_value=1) # type: ignore[assignment]
self.store.is_real_user = Mock(return_value=make_awaitable(True)) self.store.is_real_user = AsyncMock(return_value=True)
user_id = self.get_success(self.handler.register_user(localpart="real")) user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_directory_handler() directory_handler = self.hs.get_directory_handler()
@ -319,8 +314,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user( def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
self, self,
) -> None: ) -> None:
self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment] self.store.count_real_users = AsyncMock(return_value=2) # type: ignore[assignment]
self.store.is_real_user = Mock(return_value=make_awaitable(True)) self.store.is_real_user = AsyncMock(return_value=True)
user_id = self.get_success(self.handler.register_user(localpart="real")) user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0) self.assertEqual(len(rooms), 0)

View file

@ -1,4 +1,4 @@
from unittest.mock import Mock, patch from unittest.mock import AsyncMock, patch
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -16,7 +16,6 @@ from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request from tests.server import make_request
from tests.test_utils import make_awaitable
from tests.unittest import ( from tests.unittest import (
FederatingHomeserverTestCase, FederatingHomeserverTestCase,
HomeserverTestCase, HomeserverTestCase,
@ -154,25 +153,21 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
None, None,
) )
mock_make_membership_event = Mock( mock_make_membership_event = AsyncMock(
return_value=make_awaitable( return_value=(
( self.OTHER_SERVER_NAME,
self.OTHER_SERVER_NAME, join_event,
join_event, self.hs.config.server.default_room_version,
self.hs.config.server.default_room_version,
)
) )
) )
mock_send_join = Mock( mock_send_join = AsyncMock(
return_value=make_awaitable( return_value=SendJoinResult(
SendJoinResult( join_event,
join_event, self.OTHER_SERVER_NAME,
self.OTHER_SERVER_NAME, state=[create_event],
state=[create_event], auth_chain=[create_event],
auth_chain=[create_event], partial_state=False,
partial_state=False, servers_in_room=frozenset(),
servers_in_room=frozenset(),
)
) )
) )

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional from typing import Optional
from unittest.mock import MagicMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -29,7 +29,6 @@ from synapse.util import Clock
import tests.unittest import tests.unittest
import tests.utils import tests.utils
from tests.test_utils import make_awaitable
class SyncTestCase(tests.unittest.HomeserverTestCase): class SyncTestCase(tests.unittest.HomeserverTestCase):
@ -253,8 +252,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
mocked_get_prev_events = patch.object( mocked_get_prev_events = patch.object(
self.hs.get_datastores().main, self.hs.get_datastores().main,
"get_prev_events_for_room", "get_prev_events_for_room",
new_callable=MagicMock, new_callable=AsyncMock,
return_value=make_awaitable([last_room_creation_event_id]), return_value=[last_room_creation_event_id],
) )
with mocked_get_prev_events: with mocked_get_prev_events:
self.helper.join(room_id, eve, tok=eve_token) self.helper.join(room_id, eve, tok=eve_token)

View file

@ -15,7 +15,7 @@
import json import json
from typing import Dict, List, Set from typing import Dict, List, Set
from unittest.mock import ANY, Mock, call from unittest.mock import ANY, AsyncMock, Mock, call
from netaddr import IPSet from netaddr import IPSet
@ -33,7 +33,6 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import ThreadedMemoryReactorClock from tests.server import ThreadedMemoryReactorClock
from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
# Some local users to test with # Some local users to test with
@ -74,11 +73,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
# we mock out the keyring so as to skip the authentication check on the # we mock out the keyring so as to skip the authentication check on the
# federation API call. # federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"]) mock_keyring = Mock(spec=["verify_json_for_server"])
mock_keyring.verify_json_for_server.return_value = make_awaitable(True) mock_keyring.verify_json_for_server = AsyncMock(return_value=True)
# we mock out the federation client too # we mock out the federation client too
self.mock_federation_client = Mock(spec=["put_json"]) self.mock_federation_client = AsyncMock(spec=["put_json"])
self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) self.mock_federation_client.put_json.return_value = (200, "OK")
self.mock_federation_client.agent = MatrixFederationAgent( self.mock_federation_client.agent = MatrixFederationAgent(
reactor, reactor,
tls_client_options_factory=None, tls_client_options_factory=None,
@ -121,20 +120,18 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore = hs.get_datastores().main self.datastore = hs.get_datastores().main
self.datastore.get_destination_retry_timings = Mock( self.datastore.get_destination_retry_timings = AsyncMock(return_value=None)
return_value=make_awaitable(None)
self.datastore.get_device_updates_by_remote = AsyncMock( # type: ignore[assignment]
return_value=(0, [])
) )
self.datastore.get_device_updates_by_remote = Mock( # type: ignore[assignment] self.datastore.get_destination_last_successful_stream_ordering = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable((0, [])) return_value=None
) )
self.datastore.get_destination_last_successful_stream_ordering = Mock( # type: ignore[assignment] self.datastore.get_received_txn_response = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(None) return_value=None
)
self.datastore.get_received_txn_response = Mock( # type: ignore[assignment]
return_value=make_awaitable(None)
) )
self.room_members: List[UserID] = [] self.room_members: List[UserID] = []
@ -173,27 +170,25 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_users_in_room = Mock(side_effect=get_users_in_room) self.datastore.get_users_in_room = Mock(side_effect=get_users_in_room)
self.datastore.get_user_directory_stream_pos = Mock( # type: ignore[assignment] self.datastore.get_user_directory_stream_pos = AsyncMock( # type: ignore[assignment]
side_effect=( # we deliberately return a non-None stream pos to avoid
# we deliberately return a non-None stream pos to avoid # doing an initial_sync
# doing an initial_sync return_value=1
lambda: make_awaitable(1)
)
) )
self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[assignment] self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[assignment]
self.datastore.get_to_device_stream_token = Mock( # type: ignore[assignment] self.datastore.get_to_device_stream_token = Mock( # type: ignore[assignment]
side_effect=lambda: 0 return_value=0
) )
self.datastore.get_new_device_msgs_for_remote = Mock( # type: ignore[assignment] self.datastore.get_new_device_msgs_for_remote = AsyncMock( # type: ignore[assignment]
side_effect=lambda *args, **kargs: make_awaitable(([], 0)) return_value=([], 0)
) )
self.datastore.delete_device_msgs_for_remote = Mock( # type: ignore[assignment] self.datastore.delete_device_msgs_for_remote = AsyncMock( # type: ignore[assignment]
side_effect=lambda *args, **kargs: make_awaitable(None) return_value=None
) )
self.datastore.set_received_txn_response = Mock( # type: ignore[assignment] self.datastore.set_received_txn_response = AsyncMock( # type: ignore[assignment]
side_effect=lambda *args, **kwargs: make_awaitable(None) return_value=None
) )
def test_started_typing_local(self) -> None: def test_started_typing_local(self) -> None:

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Tuple from typing import Any, Tuple
from unittest.mock import Mock, patch from unittest.mock import AsyncMock, Mock, patch
from urllib.parse import quote from urllib.parse import quote
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -30,7 +30,7 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.storage.test_user_directory import GetUserDirectoryTables from tests.storage.test_user_directory import GetUserDirectoryTables
from tests.test_utils import event_injection, make_awaitable from tests.test_utils import event_injection
from tests.test_utils.event_injection import inject_member_event from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config from tests.unittest import override_config
@ -471,7 +471,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.store.register_user(user_id=r_user_id, password_hash=None) self.store.register_user(user_id=r_user_id, password_hash=None)
) )
mock_remove_from_user_dir = Mock(return_value=make_awaitable(None)) mock_remove_from_user_dir = AsyncMock(return_value=None)
with patch.object( with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir self.store, "remove_from_user_dir", mock_remove_from_user_dir
): ):

View file

@ -14,8 +14,8 @@
import base64 import base64
import logging import logging
import os import os
from typing import Any, Awaitable, Callable, Generator, List, Optional, cast from typing import Generator, List, Optional, cast
from unittest.mock import Mock, patch from unittest.mock import AsyncMock, patch
import treq import treq
from netaddr import IPSet from netaddr import IPSet
@ -41,7 +41,7 @@ from twisted.web.iweb import IPolicyForHTTPS, IResponse
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.http.federation.srv_resolver import Server from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import ( from synapse.http.federation.well_known_resolver import (
WELL_KNOWN_MAX_SIZE, WELL_KNOWN_MAX_SIZE,
WellKnownResolver, WellKnownResolver,
@ -68,21 +68,11 @@ from tests.utils import checked_cast, default_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Once Async Mocks or lambdas are supported this can go away.
def generate_resolve_service(
result: List[Server],
) -> Callable[[Any], Awaitable[List[Server]]]:
async def resolve_service(_: Any) -> List[Server]:
return result
return resolve_service
class MatrixFederationAgentTests(unittest.TestCase): class MatrixFederationAgentTests(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock() self.reactor = ThreadedMemoryReactorClock()
self.mock_resolver = Mock() self.mock_resolver = AsyncMock(spec=SrvResolver)
config_dict = default_config("test", parse=False) config_dict = default_config("test", parse=False)
config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()] config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()]
@ -636,7 +626,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
""" """
self.agent = self._make_agent() self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.mock_resolver.resolve_service.return_value = []
self.reactor.lookups["testserv1"] = "1.2.3.4" self.reactor.lookups["testserv1"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix-federation://testserv1/foo/bar") test_d = self._make_get_request(b"matrix-federation://testserv1/foo/bar")
@ -722,7 +712,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
""" """
self.agent = self._make_agent() self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.mock_resolver.resolve_service.return_value = []
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
@ -776,7 +766,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the .well-known delegates elsewhere""" """Test the behaviour when the .well-known delegates elsewhere"""
self.agent = self._make_agent() self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.mock_resolver.resolve_service.return_value = []
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f" self.reactor.lookups["target-server"] = "1::f"
@ -840,7 +830,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
""" """
self.agent = self._make_agent() self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.mock_resolver.resolve_service.return_value = []
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f" self.reactor.lookups["target-server"] = "1::f"
@ -930,7 +920,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
""" """
self.agent = self._make_agent() self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.mock_resolver.resolve_service.return_value = []
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
@ -986,7 +976,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
# the config left to the default, which will not trust it (since the # the config left to the default, which will not trust it (since the
# presented cert is signed by a test CA) # presented cert is signed by a test CA)
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.mock_resolver.resolve_service.return_value = []
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
config = default_config("test", parse=True) config = default_config("test", parse=True)
@ -1037,9 +1027,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
""" """
self.agent = self._make_agent() self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service( self.mock_resolver.resolve_service.return_value = [
[Server(host=b"srvtarget", port=8443)] Server(host=b"srvtarget", port=8443)
) ]
self.reactor.lookups["srvtarget"] = "1.2.3.4" self.reactor.lookups["srvtarget"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
@ -1094,9 +1084,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(host, "1.2.3.4") self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 443) self.assertEqual(port, 443)
self.mock_resolver.resolve_service.side_effect = generate_resolve_service( self.mock_resolver.resolve_service.return_value = [
[Server(host=b"srvtarget", port=8443)] Server(host=b"srvtarget", port=8443)
) ]
self._handle_well_known_connection( self._handle_well_known_connection(
client_factory, client_factory,
@ -1137,7 +1127,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""test the behaviour when the server name has idna chars in""" """test the behaviour when the server name has idna chars in"""
self.agent = self._make_agent() self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.mock_resolver.resolve_service.return_value = []
# the resolver is always called with the IDNA hostname as a native string. # the resolver is always called with the IDNA hostname as a native string.
self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4" self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
@ -1201,9 +1191,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""test the behaviour when the target of a SRV record has idna chars""" """test the behaviour when the target of a SRV record has idna chars"""
self.agent = self._make_agent() self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service( self.mock_resolver.resolve_service.return_value = [
[Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com Server(host=b"xn--trget-3qa.com", port=8443)
) ] # târget.com
self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4" self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
test_d = self._make_get_request( test_d = self._make_get_request(
@ -1407,12 +1397,10 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test that other SRV results are tried if the first one fails.""" """Test that other SRV results are tried if the first one fails."""
self.agent = self._make_agent() self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service( self.mock_resolver.resolve_service.return_value = [
[ Server(host=b"target.com", port=8443),
Server(host=b"target.com", port=8443), Server(host=b"target.com", port=8444),
Server(host=b"target.com", port=8444), ]
]
)
self.reactor.lookups["target.com"] = "1.2.3.4" self.reactor.lookups["target.com"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from netaddr import IPSet from netaddr import IPSet
@ -26,7 +26,6 @@ from synapse.types import UserID, create_requester
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import get_clock from tests.server import get_clock
from tests.test_utils import make_awaitable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -62,7 +61,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new event. new event.
""" """
mock_client = Mock(spec=["put_json"]) mock_client = Mock(spec=["put_json"])
mock_client.put_json.return_value = make_awaitable({}) mock_client.put_json = AsyncMock(return_value={})
mock_client.agent = self.matrix_federation_agent mock_client.agent = self.matrix_federation_agent
self.make_worker_hs( self.make_worker_hs(
"synapse.app.generic_worker", "synapse.app.generic_worker",
@ -93,7 +92,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new events. new events.
""" """
mock_client1 = Mock(spec=["put_json"]) mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.return_value = make_awaitable({}) mock_client1.put_json = AsyncMock(return_value={})
mock_client1.agent = self.matrix_federation_agent mock_client1.agent = self.matrix_federation_agent
self.make_worker_hs( self.make_worker_hs(
"synapse.app.generic_worker", "synapse.app.generic_worker",
@ -108,7 +107,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
) )
mock_client2 = Mock(spec=["put_json"]) mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.return_value = make_awaitable({}) mock_client2.put_json = AsyncMock(return_value={})
mock_client2.agent = self.matrix_federation_agent mock_client2.agent = self.matrix_federation_agent
self.make_worker_hs( self.make_worker_hs(
"synapse.app.generic_worker", "synapse.app.generic_worker",
@ -162,7 +161,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new typing EDUs. new typing EDUs.
""" """
mock_client1 = Mock(spec=["put_json"]) mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.return_value = make_awaitable({}) mock_client1.put_json = AsyncMock(return_value={})
mock_client1.agent = self.matrix_federation_agent mock_client1.agent = self.matrix_federation_agent
self.make_worker_hs( self.make_worker_hs(
"synapse.app.generic_worker", "synapse.app.generic_worker",
@ -177,7 +176,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
) )
mock_client2 = Mock(spec=["put_json"]) mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.return_value = make_awaitable({}) mock_client2.put_json = AsyncMock(return_value={})
mock_client2.agent = self.matrix_federation_agent mock_client2.agent = self.matrix_federation_agent
self.make_worker_hs( self.make_worker_hs(
"synapse.app.generic_worker", "synapse.app.generic_worker",

View file

@ -18,7 +18,7 @@ import os
import urllib.parse import urllib.parse
from binascii import unhexlify from binascii import unhexlify
from typing import List, Optional from typing import List, Optional
from unittest.mock import Mock, patch from unittest.mock import AsyncMock, Mock, patch
from parameterized import parameterized, parameterized_class from parameterized import parameterized, parameterized_class
@ -45,7 +45,7 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request from tests.server import FakeSite, make_request
from tests.test_utils import SMALL_PNG, make_awaitable from tests.test_utils import SMALL_PNG
from tests.unittest import override_config from tests.unittest import override_config
@ -419,8 +419,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
# Set monthly active users to the limit # Set monthly active users to the limit
store.get_monthly_active_count = Mock( store.get_monthly_active_count = AsyncMock(
return_value=make_awaitable(self.hs.config.server.max_mau_value) return_value=self.hs.config.server.max_mau_value
) )
# Check that the blocking of monthly active users is working as expected # Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit # The registration of a new user fails due to the limit
@ -1834,8 +1834,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
) )
# Set monthly active users to the limit # Set monthly active users to the limit
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = AsyncMock(
return_value=make_awaitable(self.hs.config.server.max_mau_value) return_value=self.hs.config.server.max_mau_value
) )
# Check that the blocking of monthly active users is working as expected # Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit # The registration of a new user fails due to the limit
@ -1871,8 +1871,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
handler = self.hs.get_registration_handler() handler = self.hs.get_registration_handler()
# Set monthly active users to the limit # Set monthly active users to the limit
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = AsyncMock(
return_value=make_awaitable(self.hs.config.server.max_mau_value) return_value=self.hs.config.server.max_mau_value
) )
# Check that the blocking of monthly active users is working as expected # Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit # The registration of a new user fails due to the limit

View file

@ -11,13 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from unittest.mock import Mock from unittest.mock import AsyncMock
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import account_data, login, room from synapse.rest.client import account_data, login, room
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
class AccountDataTestCase(unittest.HomeserverTestCase): class AccountDataTestCase(unittest.HomeserverTestCase):
@ -32,7 +31,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
"""Tests that the on_account_data_updated module callback is called correctly when """Tests that the on_account_data_updated module callback is called correctly when
a user's account data changes. a user's account data changes.
""" """
mocked_callback = Mock(return_value=make_awaitable(None)) mocked_callback = AsyncMock(return_value=None)
self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append( self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append(
mocked_callback mocked_callback
) )

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus from http import HTTPStatus
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -23,7 +23,6 @@ from synapse.types import UserID
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
class PresenceTestCase(unittest.HomeserverTestCase): class PresenceTestCase(unittest.HomeserverTestCase):
@ -36,7 +35,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.presence_handler = Mock(spec=PresenceHandler) self.presence_handler = Mock(spec=PresenceHandler)
self.presence_handler.set_state.return_value = make_awaitable(None) self.presence_handler.set_state = AsyncMock(return_value=None)
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
"red", "red",

View file

@ -15,7 +15,7 @@
import urllib.parse import urllib.parse
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch from unittest.mock import AsyncMock, patch
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -28,7 +28,6 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import FakeChannel from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_event from tests.test_utils.event_injection import inject_event
from tests.unittest import override_config from tests.unittest import override_config
@ -264,7 +263,8 @@ class RelationsTestCase(BaseRelationsTestCase):
# Disable the validation to pretend this came over federation. # Disable the validation to pretend this came over federation.
with patch( with patch(
"synapse.handlers.message.EventCreationHandler._validate_event_relation", "synapse.handlers.message.EventCreationHandler._validate_event_relation",
new=lambda self, event: make_awaitable(None), new_callable=AsyncMock,
return_value=None,
): ):
# Generate a various relations from a different room. # Generate a various relations from a different room.
self.get_success( self.get_success(
@ -1300,7 +1300,8 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# not an event the Client-Server API will allow.. # not an event the Client-Server API will allow..
with patch( with patch(
"synapse.handlers.message.EventCreationHandler._validate_event_relation", "synapse.handlers.message.EventCreationHandler._validate_event_relation",
new=lambda self, event: make_awaitable(None), new_callable=AsyncMock,
return_value=None,
): ):
# Create a sub-thread off the thread, which is not allowed. # Create a sub-thread off the thread, which is not allowed.
self._send_relation( self._send_relation(

View file

@ -20,7 +20,7 @@
import json import json
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from unittest.mock import Mock, call, patch from unittest.mock import AsyncMock, Mock, call, patch
from urllib import parse as urlparse from urllib import parse as urlparse
from parameterized import param, parameterized from parameterized import param, parameterized
@ -52,7 +52,6 @@ from synapse.util.stringutils import random_string
from tests import unittest from tests import unittest
from tests.http.server._base import make_request_with_cancellation_test from tests.http.server._base import make_request_with_cancellation_test
from tests.storage.test_stream import PaginationTestCase from tests.storage.test_stream import PaginationTestCase
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import create_event from tests.test_utils.event_injection import create_event
from tests.unittest import override_config from tests.unittest import override_config
@ -70,8 +69,8 @@ class RoomBase(unittest.HomeserverTestCase):
) )
self.hs.get_federation_handler = Mock() # type: ignore[assignment] self.hs.get_federation_handler = Mock() # type: ignore[assignment]
self.hs.get_federation_handler.return_value.maybe_backfill = Mock( self.hs.get_federation_handler.return_value.maybe_backfill = AsyncMock(
return_value=make_awaitable(None) return_value=None
) )
async def _insert_client_ip(*args: Any, **kwargs: Any) -> None: async def _insert_client_ip(*args: Any, **kwargs: Any) -> None:
@ -2375,7 +2374,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
] ]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=Mock()) return self.setup_test_homeserver(federation_client=AsyncMock())
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.register_user("user", "pass") self.register_user("user", "pass")
@ -2385,7 +2384,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
def test_simple(self) -> None: def test_simple(self) -> None:
"Simple test for searching rooms over federation" "Simple test for searching rooms over federation"
self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined] self.federation_client.get_public_rooms.return_value = {} # type: ignore[attr-defined]
search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
@ -2413,7 +2412,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
# with a 404, when using search filters. # with a 404, when using search filters.
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""), HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""),
make_awaitable({}), {},
) )
search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
@ -3413,17 +3412,17 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Mock a few functions to prevent the test from failing due to failing to talk to # Mock a few functions to prevent the test from failing due to failing to talk to
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test. # can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0))
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(None), return_value=None,
) )
# Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
# allow everything for now. # allow everything for now.
# `spec` argument is needed for this function mock to have `__qualname__`, which # `spec` argument is needed for this function mock to have `__qualname__`, which
# is needed for `Measure` metrics buried in SpamChecker. # is needed for `Measure` metrics buried in SpamChecker.
mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None) mock = AsyncMock(return_value=True, spec=lambda *x: None)
self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append( self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
mock mock
) )
@ -3451,7 +3450,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Now change the return value of the callback to deny any invite and test that # Now change the return value of the callback to deny any invite and test that
# we can't send the invite. # we can't send the invite.
mock.return_value = make_awaitable(False) mock.return_value = False
channel = self.make_request( channel = self.make_request(
method="POST", method="POST",
path="/rooms/" + self.room_id + "/invite", path="/rooms/" + self.room_id + "/invite",
@ -3477,18 +3476,18 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Mock a few functions to prevent the test from failing due to failing to talk to # Mock a few functions to prevent the test from failing due to failing to talk to
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test. # can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0))
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(None), return_value=None,
) )
# Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
# allow everything for now. # allow everything for now.
# `spec` argument is needed for this function mock to have `__qualname__`, which # `spec` argument is needed for this function mock to have `__qualname__`, which
# is needed for `Measure` metrics buried in SpamChecker. # is needed for `Measure` metrics buried in SpamChecker.
mock = Mock( mock = AsyncMock(
return_value=make_awaitable(synapse.module_api.NOT_SPAM), return_value=synapse.module_api.NOT_SPAM,
spec=lambda *x: None, spec=lambda *x: None,
) )
self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append( self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
@ -3519,7 +3518,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Now change the return value of the callback to deny any invite and test that # Now change the return value of the callback to deny any invite and test that
# we can't send the invite. We pick an arbitrary error code to be able to check # we can't send the invite. We pick an arbitrary error code to be able to check
# that the same code has been returned # that the same code has been returned
mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN) mock.return_value = Codes.CONSENT_NOT_GIVEN
channel = self.make_request( channel = self.make_request(
method="POST", method="POST",
path="/rooms/" + self.room_id + "/invite", path="/rooms/" + self.room_id + "/invite",
@ -3538,7 +3537,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
make_invite_mock.assert_called_once() make_invite_mock.assert_called_once()
# Run variant with `Tuple[Codes, dict]`. # Run variant with `Tuple[Codes, dict]`.
mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"})) mock.return_value = (Codes.EXPIRED_ACCOUNT, {"field": "value"})
channel = self.make_request( channel = self.make_request(
method="POST", method="POST",
path="/rooms/" + self.room_id + "/invite", path="/rooms/" + self.room_id + "/invite",

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import threading import threading
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -33,7 +33,6 @@ from synapse.util import Clock
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
@ -477,7 +476,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
def test_on_new_event(self) -> None: def test_on_new_event(self) -> None:
"""Test that the on_new_event callback is called on new events""" """Test that the on_new_event callback is called on new events"""
on_new_event = Mock(make_awaitable(None)) on_new_event = AsyncMock(return_value=None)
self.hs.get_module_api_callbacks().third_party_event_rules._on_new_event_callbacks.append( self.hs.get_module_api_callbacks().third_party_event_rules._on_new_event_callbacks.append(
on_new_event on_new_event
) )
@ -580,7 +579,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo"
# Register a mock callback. # Register a mock callback.
m = Mock(return_value=make_awaitable(None)) m = AsyncMock(return_value=None)
self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append( self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
m m
) )
@ -641,7 +640,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo"
# Register a mock callback. # Register a mock callback.
m = Mock(return_value=make_awaitable(None)) m = AsyncMock(return_value=None)
self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append( self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
m m
) )
@ -682,7 +681,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
correctly when processing a user's deactivation. correctly when processing a user's deactivation.
""" """
# Register a mocked callback. # Register a mocked callback.
deactivation_mock = Mock(return_value=make_awaitable(None)) deactivation_mock = AsyncMock(return_value=None)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._on_user_deactivation_status_changed_callbacks.append( third_party_rules._on_user_deactivation_status_changed_callbacks.append(
deactivation_mock, deactivation_mock,
@ -690,7 +689,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Also register a mocked callback for profile updates, to check that the # Also register a mocked callback for profile updates, to check that the
# deactivation code calls it in a way that let modules know the user is being # deactivation code calls it in a way that let modules know the user is being
# deactivated. # deactivated.
profile_mock = Mock(return_value=make_awaitable(None)) profile_mock = AsyncMock(return_value=None)
self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append( self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
profile_mock, profile_mock,
) )
@ -740,7 +739,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
well as a reactivation. well as a reactivation.
""" """
# Register a mock callback. # Register a mock callback.
m = Mock(return_value=make_awaitable(None)) m = AsyncMock(return_value=None)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._on_user_deactivation_status_changed_callbacks.append(m) third_party_rules._on_user_deactivation_status_changed_callbacks.append(m)
@ -794,7 +793,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
correctly when processing a user's deactivation. correctly when processing a user's deactivation.
""" """
# Register a mocked callback. # Register a mocked callback.
deactivation_mock = Mock(return_value=make_awaitable(False)) deactivation_mock = AsyncMock(return_value=False)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._check_can_deactivate_user_callbacks.append( third_party_rules._check_can_deactivate_user_callbacks.append(
deactivation_mock, deactivation_mock,
@ -840,7 +839,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
correctly when processing a user's deactivation triggered by a server admin. correctly when processing a user's deactivation triggered by a server admin.
""" """
# Register a mocked callback. # Register a mocked callback.
deactivation_mock = Mock(return_value=make_awaitable(False)) deactivation_mock = AsyncMock(return_value=False)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._check_can_deactivate_user_callbacks.append( third_party_rules._check_can_deactivate_user_callbacks.append(
deactivation_mock, deactivation_mock,
@ -879,7 +878,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
correctly when processing an admin's shutdown room request. correctly when processing an admin's shutdown room request.
""" """
# Register a mocked callback. # Register a mocked callback.
shutdown_mock = Mock(return_value=make_awaitable(False)) shutdown_mock = AsyncMock(return_value=False)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._check_can_shutdown_room_callbacks.append( third_party_rules._check_can_shutdown_room_callbacks.append(
shutdown_mock, shutdown_mock,
@ -915,7 +914,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
associating a 3PID to an account. associating a 3PID to an account.
""" """
# Register a mocked callback. # Register a mocked callback.
threepid_bind_mock = Mock(return_value=make_awaitable(None)) threepid_bind_mock = AsyncMock(return_value=None)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock) third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock)
@ -957,11 +956,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
just before associating and removing a 3PID to/from an account. just before associating and removing a 3PID to/from an account.
""" """
# Pretend to be a Synapse module and register both callbacks as mocks. # Pretend to be a Synapse module and register both callbacks as mocks.
on_add_user_third_party_identifier_callback_mock = Mock( on_add_user_third_party_identifier_callback_mock = AsyncMock(return_value=None)
return_value=make_awaitable(None) on_remove_user_third_party_identifier_callback_mock = AsyncMock(
) return_value=None
on_remove_user_third_party_identifier_callback_mock = Mock(
return_value=make_awaitable(None)
) )
self.hs.get_module_api().register_third_party_rules_callbacks( self.hs.get_module_api().register_third_party_rules_callbacks(
on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock, on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock,
@ -1021,8 +1018,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
when a user is deactivated and their third-party ID associations are deleted. when a user is deactivated and their third-party ID associations are deleted.
""" """
# Pretend to be a Synapse module and register both callbacks as mocks. # Pretend to be a Synapse module and register both callbacks as mocks.
on_remove_user_third_party_identifier_callback_mock = Mock( on_remove_user_third_party_identifier_callback_mock = AsyncMock(
return_value=make_awaitable(None) return_value=None
) )
self.hs.get_module_api().register_third_party_rules_callbacks( self.hs.get_module_api().register_third_party_rules_callbacks(
on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock, on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,

View file

@ -14,7 +14,7 @@
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Generator, Tuple, cast from typing import Any, Generator, Tuple, cast
from unittest.mock import Mock, call from unittest.mock import AsyncMock, Mock, call
from twisted.internet import defer, reactor as _reactor from twisted.internet import defer, reactor as _reactor
@ -24,7 +24,6 @@ from synapse.types import ISynapseReactor, JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
from tests.utils import MockClock from tests.utils import MockClock
reactor = cast(ISynapseReactor, _reactor) reactor = cast(ISynapseReactor, _reactor)
@ -53,7 +52,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def test_executes_given_function( def test_executes_given_function(
self, self,
) -> Generator["defer.Deferred[Any]", object, None]: ) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response)) cb = AsyncMock(return_value=self.mock_http_response)
res = yield self.cache.fetch_or_execute_request( res = yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg" self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg"
) )
@ -64,7 +63,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def test_deduplicates_based_on_key( def test_deduplicates_based_on_key(
self, self,
) -> Generator["defer.Deferred[Any]", object, None]: ) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response)) cb = AsyncMock(return_value=self.mock_http_response)
for i in range(3): # invoke multiple times for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute_request( res = yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_request,
@ -168,7 +167,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]: def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response)) cb = AsyncMock(return_value=self.mock_http_response)
yield self.cache.fetch_or_execute_request( yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "an arg" self.mock_request, self.mock_requester, cb, "an arg"
) )

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Tuple from typing import Tuple
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -29,7 +29,6 @@ from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
from tests.utils import default_config from tests.utils import default_config
@ -69,24 +68,22 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
assert isinstance(rlsn, ResourceLimitsServerNotices) assert isinstance(rlsn, ResourceLimitsServerNotices)
self._rlsn = rlsn self._rlsn = rlsn
self._rlsn._store.user_last_seen_monthly_active = Mock( self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=1000)
return_value=make_awaitable(1000) self._rlsn._server_notices_manager.send_notice = AsyncMock( # type: ignore[assignment]
) return_value=Mock()
self._rlsn._server_notices_manager.send_notice = Mock( # type: ignore[assignment]
return_value=make_awaitable(Mock())
) )
self._send_notice = self._rlsn._server_notices_manager.send_notice self._send_notice = self._rlsn._server_notices_manager.send_notice
self.user_id = "@user_id:test" self.user_id = "@user_id:test"
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = (
return_value=make_awaitable("!something:localhost") AsyncMock(return_value="!something:localhost")
) )
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = Mock( self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = AsyncMock(
return_value=make_awaitable("!something:localhost") return_value="!something:localhost"
) )
self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self._rlsn._store.add_tag_to_room = AsyncMock(return_value=None) # type: ignore[assignment]
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment] self._rlsn._store.get_tags_for_room = AsyncMock(return_value={}) # type: ignore[assignment]
@override_config({"hs_disabled": True}) @override_config({"hs_disabled": True})
def test_maybe_send_server_notice_disabled_hs(self) -> None: def test_maybe_send_server_notice_disabled_hs(self) -> None:
@ -103,14 +100,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None: def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None:
"""Test when user has blocked notice, but should have it removed""" """Test when user has blocked notice, but should have it removed"""
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(None) return_value=None
) )
mock_event = Mock( mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
) )
self._rlsn._store.get_events = Mock( # type: ignore[assignment] self._rlsn._store.get_events = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event}) return_value={"123": mock_event}
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event # Would be better to check the content, but once == remove blocking event
@ -125,16 +122,16 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
""" """
Test when user has blocked notice, but notice ought to be there (NOOP) Test when user has blocked notice, but notice ought to be there (NOOP)
""" """
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(None), return_value=None,
side_effect=ResourceLimitError(403, "foo"), side_effect=ResourceLimitError(403, "foo"),
) )
mock_event = Mock( mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
) )
self._rlsn._store.get_events = Mock( # type: ignore[assignment] self._rlsn._store.get_events = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event}) return_value={"123": mock_event}
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -145,8 +142,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
""" """
Test when user does not have blocked notice, but should have one Test when user does not have blocked notice, but should have one
""" """
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(None), return_value=None,
side_effect=ResourceLimitError(403, "foo"), side_effect=ResourceLimitError(403, "foo"),
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -158,8 +155,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
""" """
Test when user does not have blocked notice, nor should they (NOOP) Test when user does not have blocked notice, nor should they (NOOP)
""" """
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(None) return_value=None
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -171,12 +168,10 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user is not part of the MAU cohort - this should not ever Test when user is not part of the MAU cohort - this should not ever
happen - but ... happen - but ...
""" """
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(None) return_value=None
)
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(None)
) )
self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=None)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called() self._send_notice.assert_not_called()
@ -189,8 +184,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test that when server is over MAU limit and alerting is suppressed, then Test that when server is over MAU limit and alerting is suppressed, then
an alert message is not sent into the room an alert message is not sent into the room
""" """
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(None), return_value=None,
side_effect=ResourceLimitError( side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
), ),
@ -204,8 +199,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
""" """
Test that when a server is disabled, that MAU limit alerting is ignored. Test that when a server is disabled, that MAU limit alerting is ignored.
""" """
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(None), return_value=None,
side_effect=ResourceLimitError( side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
), ),
@ -223,22 +218,22 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
When the room is already in a blocked state, test that when alerting When the room is already in a blocked state, test that when alerting
is suppressed that the room is returned to an unblocked state. is suppressed that the room is returned to an unblocked state.
""" """
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable(None), return_value=None,
side_effect=ResourceLimitError( side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
), ),
) )
self._rlsn._is_room_currently_blocked = Mock( # type: ignore[assignment] self._rlsn._is_room_currently_blocked = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable((True, [])) return_value=(True, [])
) )
mock_event = Mock( mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
) )
self._rlsn._store.get_events = Mock( # type: ignore[assignment] self._rlsn._store.get_events = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event}) return_value={"123": mock_event}
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -284,11 +279,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test" self.user_id = "@user_id:test"
def test_server_notice_only_sent_once(self) -> None: def test_server_notice_only_sent_once(self) -> None:
self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000)) self.store.get_monthly_active_count = AsyncMock(return_value=1000)
self.store.user_last_seen_monthly_active = Mock( self.store.user_last_seen_monthly_active = AsyncMock(return_value=1000)
return_value=make_awaitable(1000)
)
# Call the function multiple times to ensure we only send the notice once # Call the function multiple times to ensure we only send the notice once
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -327,7 +320,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
hasn't been reached (since it's the only user and the limit is 5), so users hasn't been reached (since it's the only user and the limit is 5), so users
shouldn't receive a server notice. shouldn't receive a server notice.
""" """
m = Mock(return_value=make_awaitable(None)) m = AsyncMock(return_value=None)
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = m self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = m
user_id = self.register_user("user", "password") user_id = self.register_user("user", "password")

View file

@ -15,7 +15,7 @@ import json
import os import os
import tempfile import tempfile
from typing import List, cast from typing import List, cast
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
import yaml import yaml
@ -35,7 +35,6 @@ from synapse.types import DeviceListUpdates
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
@ -339,7 +338,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
# we aren't testing store._base stuff here, so mock this out # we aren't testing store._base stuff here, so mock this out
# (ignore needed because Mypy won't allow us to assign to a method otherwise) # (ignore needed because Mypy won't allow us to assign to a method otherwise)
self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) # type: ignore[assignment] self.store.get_events_as_list = AsyncMock(return_value=events) # type: ignore[assignment]
self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events)) self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events))
self.get_success(self._insert_txn(service.id, 10, events)) self.get_success(self._insert_txn(service.id, 10, events))

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
import yaml import yaml
@ -32,7 +32,7 @@ from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable, simple_async_mock from tests.test_utils import simple_async_mock
from tests.unittest import override_config from tests.unittest import override_config
@ -363,9 +363,9 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
# Register the callbacks with more mocks # Register the callbacks with more mocks
self.hs.get_module_api().register_background_update_controller_callbacks( self.hs.get_module_api().register_background_update_controller_callbacks(
on_update=self._on_update, on_update=self._on_update,
min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)), min_batch_size=AsyncMock(return_value=self._default_batch_size),
default_batch_size=Mock( default_batch_size=AsyncMock(
return_value=make_awaitable(self._default_batch_size), return_value=self._default_batch_size,
), ),
) )

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from typing import Any, Dict from typing import Any, Dict
from unittest.mock import Mock from unittest.mock import AsyncMock
from parameterized import parameterized from parameterized import parameterized
@ -30,7 +30,6 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import make_request from tests.server import make_request
from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
@ -443,9 +442,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
lots_of_users = 100 lots_of_users = 100
user_id = "@user:server" user_id = "@user:server"
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users)
return_value=make_awaitable(lots_of_users)
)
self.get_success( self.get_success(
self.store.insert_client_ip( self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id" user_id, "access_token", "ip", "user_agent", "device_id"

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, List from typing import Any, Dict, List
from unittest.mock import Mock from unittest.mock import AsyncMock
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -21,7 +21,6 @@ from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import default_config, override_config from tests.unittest import default_config, override_config
FORTY_DAYS = 40 * 24 * 60 * 60 FORTY_DAYS = 40 * 24 * 60 * 60
@ -253,7 +252,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
) )
self.get_success(d) self.get_success(d)
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment]
d = self.store.populate_monthly_active_users(user_id) d = self.store.populate_monthly_active_users(user_id)
self.get_success(d) self.get_success(d)
@ -261,24 +260,22 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called() self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self) -> None: def test_populate_monthly_users_should_update(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[assignment]
self.store.user_last_seen_monthly_active = Mock( self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
return_value=make_awaitable(None)
)
d = self.store.populate_monthly_active_users("user_id") d = self.store.populate_monthly_active_users("user_id")
self.get_success(d) self.get_success(d)
self.store.upsert_monthly_active_user.assert_called_once() self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self) -> None: def test_populate_monthly_users_should_not_update(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[assignment]
self.store.user_last_seen_monthly_active = Mock( self.store.user_last_seen_monthly_active = AsyncMock(
return_value=make_awaitable(self.hs.get_clock().time_msec()) return_value=self.hs.get_clock().time_msec()
) )
d = self.store.populate_monthly_active_users("user_id") d = self.store.populate_monthly_active_users("user_id")
@ -359,7 +356,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self) -> None: def test_no_users_when_not_tracking(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment]
self.get_success(self.store.populate_monthly_active_users("@user:sever")) self.get_success(self.store.populate_monthly_active_users("@user:sever"))

View file

@ -22,7 +22,6 @@ from synapse.storage.util.partial_state_events_tracker import (
PartialStateEventsTracker, PartialStateEventsTracker,
) )
from tests.test_utils import make_awaitable
from tests.unittest import TestCase from tests.unittest import TestCase
@ -124,16 +123,17 @@ class PartialStateEventsTrackerTestCase(TestCase):
class PartialCurrentStateTrackerTestCase(TestCase): class PartialCurrentStateTrackerTestCase(TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.mock_store = mock.Mock(spec_set=["is_partial_state_room"]) self.mock_store = mock.Mock(spec_set=["is_partial_state_room"])
self.mock_store.is_partial_state_room = mock.AsyncMock()
self.tracker = PartialCurrentStateTracker(self.mock_store) self.tracker = PartialCurrentStateTracker(self.mock_store)
def test_does_not_block_for_full_state_rooms(self) -> None: def test_does_not_block_for_full_state_rooms(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(False) self.mock_store.is_partial_state_room.return_value = False
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
def test_blocks_for_partial_room_state(self) -> None: def test_blocks_for_partial_room_state(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(True) self.mock_store.is_partial_state_room.return_value = True
d = ensureDeferred(self.tracker.await_full_state("room_id")) d = ensureDeferred(self.tracker.await_full_state("room_id"))
@ -156,7 +156,7 @@ class PartialCurrentStateTrackerTestCase(TestCase):
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
def test_cancellation(self) -> None: def test_cancellation(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(True) self.mock_store.is_partial_state_room.return_value = True
d1 = ensureDeferred(self.tracker.await_full_state("room_id")) d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
self.assertNoResult(d1) self.assertNoResult(d1)

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from typing import Collection, List, Optional, Union from typing import Collection, List, Optional, Union
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -31,7 +31,6 @@ from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
class MessageAcceptTests(unittest.HomeserverTestCase): class MessageAcceptTests(unittest.HomeserverTestCase):
@ -196,7 +195,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register a mock on the store so that the incoming update doesn't fail because # Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user. # we don't share a room with the user.
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) store.get_rooms_for_user = AsyncMock(return_value=["!someroom:test"])
# Manually inject a fake device list update. We need this update to include at # Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried. # least one prev_id so that the user's device list will need to be retried.
@ -241,27 +240,24 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register mock device list retrieval on the federation client. # Register mock device list retrieval on the federation client.
federation_client = self.hs.get_federation_client() federation_client = self.hs.get_federation_client()
federation_client.query_user_devices = Mock( # type: ignore[assignment] federation_client.query_user_devices = AsyncMock( # type: ignore[assignment]
return_value=make_awaitable( return_value={
{ "user_id": remote_user_id,
"stream_id": 1,
"devices": [],
"master_key": {
"user_id": remote_user_id, "user_id": remote_user_id,
"stream_id": 1, "usage": ["master"],
"devices": [], "keys": {"ed25519:" + remote_master_key: remote_master_key},
"master_key": { },
"user_id": remote_user_id, "self_signing_key": {
"usage": ["master"], "user_id": remote_user_id,
"keys": {"ed25519:" + remote_master_key: remote_master_key}, "usage": ["self_signing"],
"keys": {
"ed25519:" + remote_self_signing_key: remote_self_signing_key
}, },
"self_signing_key": { },
"user_id": remote_user_id, }
"usage": ["self_signing"],
"keys": {
"ed25519:"
+ remote_self_signing_key: remote_self_signing_key
},
},
}
)
) )
# Resync the device list. # Resync the device list.

View file

@ -18,7 +18,6 @@ Utilities for running the unit tests
import json import json
import sys import sys
import warnings import warnings
from asyncio import Future
from binascii import unhexlify from binascii import unhexlify
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
from unittest.mock import Mock from unittest.mock import Mock
@ -57,17 +56,6 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
raise Exception("awaitable has not yet completed") raise Exception("awaitable has not yet completed")
def make_awaitable(result: TV) -> Awaitable[TV]:
"""
Makes an awaitable, suitable for mocking an `async` function.
This uses Futures as they can be awaited multiple times so can be returned
to multiple callers.
"""
future: Future[TV] = Future()
future.set_result(result)
return future
def setup_awaitable_errors() -> Callable[[], None]: def setup_awaitable_errors() -> Callable[[], None]:
""" """
Convert warnings from a non-awaited coroutines into errors. Convert warnings from a non-awaited coroutines into errors.