Do not yield on awaitables in tests. (#8193)

This commit is contained in:
Patrick Cloke 2020-08-27 17:24:46 -04:00 committed by GitHub
parent b49a5b9307
commit e00816ad98
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 230 additions and 131 deletions

1
changelog.d/8193.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -369,8 +369,10 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter_presence_match(self): def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}} user_filter_json = {"presence": {"types": ["m.*"]}}
filter_id = yield self.datastore.add_user_filter( filter_id = yield defer.ensureDeferred(
user_localpart=user_localpart, user_filter=user_filter_json self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
) )
event = MockEvent(sender="@foo:bar", type="m.profile") event = MockEvent(sender="@foo:bar", type="m.profile")
events = [event] events = [event]
@ -388,8 +390,10 @@ class FilteringTestCase(unittest.TestCase):
def test_filter_presence_no_match(self): def test_filter_presence_no_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}} user_filter_json = {"presence": {"types": ["m.*"]}}
filter_id = yield self.datastore.add_user_filter( filter_id = yield defer.ensureDeferred(
user_localpart=user_localpart + "2", user_filter=user_filter_json self.datastore.add_user_filter(
user_localpart=user_localpart + "2", user_filter=user_filter_json
)
) )
event = MockEvent( event = MockEvent(
event_id="$asdasd:localhost", event_id="$asdasd:localhost",
@ -410,8 +414,10 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter_room_state_match(self): def test_filter_room_state_match(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}} user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = yield self.datastore.add_user_filter( filter_id = yield defer.ensureDeferred(
user_localpart=user_localpart, user_filter=user_filter_json self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
) )
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
events = [event] events = [event]
@ -428,8 +434,10 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter_room_state_no_match(self): def test_filter_room_state_no_match(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}} user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = yield self.datastore.add_user_filter( filter_id = yield defer.ensureDeferred(
user_localpart=user_localpart, user_filter=user_filter_json self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
) )
event = MockEvent( event = MockEvent(
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar" sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase):
def test_add_filter(self): def test_add_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}} user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = yield self.filtering.add_user_filter( filter_id = yield defer.ensureDeferred(
user_localpart=user_localpart, user_filter=user_filter_json self.filtering.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
) )
self.assertEquals(filter_id, 0) self.assertEquals(filter_id, 0)
@ -485,8 +495,10 @@ class FilteringTestCase(unittest.TestCase):
def test_get_filter(self): def test_get_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}} user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
filter_id = yield self.datastore.add_user_filter( filter_id = yield defer.ensureDeferred(
user_localpart=user_localpart, user_filter=user_filter_json self.datastore.add_user_filter(
user_localpart=user_localpart, user_filter=user_filter_json
)
) )
filter = yield defer.ensureDeferred( filter = yield defer.ensureDeferred(

View file

@ -190,7 +190,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# should fail immediately on an unsigned object # should fail immediately on an unsigned object
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
self.failureResultOf(d, SynapseError) self.get_failure(d, SynapseError)
# should succeed on a signed object # should succeed on a signed object
d = _verify_json_for_server(kr, "server9", json1, 500, "test signed") d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
@ -221,7 +221,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# should fail immediately on an unsigned object # should fail immediately on an unsigned object
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
self.failureResultOf(d, SynapseError) self.get_failure(d, SynapseError)
# should fail on a signed object with a non-zero minimum_valid_until_ms, # should fail on a signed object with a non-zero minimum_valid_until_ms,
# as it tries to refetch the keys and fails. # as it tries to refetch the keys and fails.

View file

@ -15,8 +15,6 @@
from mock import Mock from mock import Mock
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client.v1 import login, room from synapse.rest.client.v1 import login, room
@ -60,7 +58,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Artificially raise the complexity # Artificially raise the complexity
store = self.hs.get_datastore() store = self.hs.get_datastore()
store.get_current_state_event_counts = lambda x: defer.succeed(500 * 1.23) store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
# Get the room complexity again -- make sure it's our artificial value # Get the room complexity again -- make sure it's our artificial value
request, channel = self.make_request( request, channel = self.make_request(
@ -160,7 +158,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
) )
# Artificially raise the complexity # Artificially raise the complexity
self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed( self.hs.get_datastore().get_current_state_event_counts = lambda x: make_awaitable(
600 600
) )

View file

@ -155,7 +155,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable( self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
([], 0) ([], 0)
) )
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
None
)
self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable( self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
None None
) )

View file

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.rest.client.v2_alpha import filter from synapse.rest.client.v2_alpha import filter
@ -73,8 +75,10 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_get_filter(self): def test_get_filter(self):
filter_id = self.filtering.add_user_filter( filter_id = defer.ensureDeferred(
user_localpart="apple", user_filter=self.EXAMPLE_FILTER self.filtering.add_user_filter(
user_localpart="apple", user_filter=self.EXAMPLE_FILTER
)
) )
self.reactor.advance(1) self.reactor.advance(1)
filter_id = filter_id.result filter_id = filter_id.result

View file

@ -243,7 +243,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
def test_create_appservice_txn_first(self): def test_create_appservice_txn_first(self):
service = Mock(id=self.as_list[0]["id"]) service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")] events = [Mock(event_id="e1"), Mock(event_id="e2")]
txn = yield self.store.create_appservice_txn(service, events) txn = yield defer.ensureDeferred(
self.store.create_appservice_txn(service, events)
)
self.assertEquals(txn.id, 1) self.assertEquals(txn.id, 1)
self.assertEquals(txn.events, events) self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service) self.assertEquals(txn.service, service)
@ -255,7 +257,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_last_txn(service.id, 9643) # AS is falling behind yield self._set_last_txn(service.id, 9643) # AS is falling behind
yield self._insert_txn(service.id, 9644, events) yield self._insert_txn(service.id, 9644, events)
yield self._insert_txn(service.id, 9645, events) yield self._insert_txn(service.id, 9645, events)
txn = yield self.store.create_appservice_txn(service, events) txn = yield defer.ensureDeferred(
self.store.create_appservice_txn(service, events)
)
self.assertEquals(txn.id, 9646) self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events) self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service) self.assertEquals(txn.service, service)
@ -265,7 +269,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
service = Mock(id=self.as_list[0]["id"]) service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")] events = [Mock(event_id="e1"), Mock(event_id="e2")]
yield self._set_last_txn(service.id, 9643) yield self._set_last_txn(service.id, 9643)
txn = yield self.store.create_appservice_txn(service, events) txn = yield defer.ensureDeferred(
self.store.create_appservice_txn(service, events)
)
self.assertEquals(txn.id, 9644) self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events) self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service) self.assertEquals(txn.service, service)
@ -286,7 +292,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._insert_txn(self.as_list[2]["id"], 10, events) yield self._insert_txn(self.as_list[2]["id"], 10, events)
yield self._insert_txn(self.as_list[3]["id"], 9643, events) yield self._insert_txn(self.as_list[3]["id"], 9643, events)
txn = yield self.store.create_appservice_txn(service, events) txn = yield defer.ensureDeferred(
self.store.create_appservice_txn(service, events)
)
self.assertEquals(txn.id, 9644) self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events) self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service) self.assertEquals(txn.service, service)
@ -298,7 +306,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
txn_id = 1 txn_id = 1
yield self._insert_txn(service.id, txn_id, events) yield self._insert_txn(service.id, txn_id, events)
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) yield defer.ensureDeferred(
self.store.complete_appservice_txn(txn_id=txn_id, service=service)
)
res = yield self.db_pool.runQuery( res = yield self.db_pool.runQuery(
self.engine.convert_param_style( self.engine.convert_param_style(
@ -324,7 +334,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
txn_id = 5 txn_id = 5
yield self._set_last_txn(service.id, 4) yield self._set_last_txn(service.id, 4)
yield self._insert_txn(service.id, txn_id, events) yield self._insert_txn(service.id, txn_id, events)
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) yield defer.ensureDeferred(
self.store.complete_appservice_txn(txn_id=txn_id, service=service)
)
res = yield self.db_pool.runQuery( res = yield self.db_pool.runQuery(
self.engine.convert_param_style( self.engine.convert_param_style(

View file

@ -1,7 +1,5 @@
from mock import Mock from mock import Mock
from twisted.internet import defer
from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.background_updates import BackgroundUpdater
from tests import unittest from tests import unittest
@ -38,11 +36,10 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
) )
# first step: make a bit of progress # first step: make a bit of progress
@defer.inlineCallbacks async def update(progress, count):
def update(progress, count): await self.clock.sleep((count * duration_ms) / 1000)
yield self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1} progress = {"my_key": progress["my_key"] + 1}
yield store.db_pool.runInteraction( await store.db_pool.runInteraction(
"update_progress", "update_progress",
self.updates._background_update_progress_txn, self.updates._background_update_progress_txn,
"test_update", "test_update",

View file

@ -32,7 +32,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(self.store.store_device("user", "device", None)) yield defer.ensureDeferred(self.store.store_device("user", "device", None))
yield self.store.set_e2e_device_keys("user", "device", now, json) yield defer.ensureDeferred(
self.store.set_e2e_device_keys("user", "device", now, json)
)
res = yield defer.ensureDeferred( res = yield defer.ensureDeferred(
self.store.get_e2e_device_keys((("user", "device"),)) self.store.get_e2e_device_keys((("user", "device"),))
@ -49,12 +51,16 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(self.store.store_device("user", "device", None)) yield defer.ensureDeferred(self.store.store_device("user", "device", None))
changed = yield self.store.set_e2e_device_keys("user", "device", now, json) changed = yield defer.ensureDeferred(
self.store.set_e2e_device_keys("user", "device", now, json)
)
self.assertTrue(changed) self.assertTrue(changed)
# If we try to upload the same key then we should be told nothing # If we try to upload the same key then we should be told nothing
# changed # changed
changed = yield self.store.set_e2e_device_keys("user", "device", now, json) changed = yield defer.ensureDeferred(
self.store.set_e2e_device_keys("user", "device", now, json)
)
self.assertFalse(changed) self.assertFalse(changed)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -62,7 +68,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
now = 1470174257070 now = 1470174257070
json = {"key": "value"} json = {"key": "value"}
yield self.store.set_e2e_device_keys("user", "device", now, json) yield defer.ensureDeferred(
self.store.set_e2e_device_keys("user", "device", now, json)
)
yield defer.ensureDeferred( yield defer.ensureDeferred(
self.store.store_device("user", "device", "display_name") self.store.store_device("user", "device", "display_name")
) )
@ -86,10 +94,18 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(self.store.store_device("user2", "device1", None)) yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
yield defer.ensureDeferred(self.store.store_device("user2", "device2", None)) yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) yield defer.ensureDeferred(
yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) )
yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) yield defer.ensureDeferred(
self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
)
yield defer.ensureDeferred(
self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
)
yield defer.ensureDeferred(
self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
)
res = yield defer.ensureDeferred( res = yield defer.ensureDeferred(
self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2"))) self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))

View file

@ -60,8 +60,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count): def _assert_counts(noitf_count, highlight_count):
counts = yield self.store.db_pool.runInteraction( counts = yield defer.ensureDeferred(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 self.store.db_pool.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
)
) )
self.assertEquals( self.assertEquals(
counts, counts,
@ -81,25 +83,31 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.event_id, {user_id: action} event.event_id, {user_id: action}
) )
) )
yield self.store.db_pool.runInteraction( yield defer.ensureDeferred(
"", self.store.db_pool.runInteraction(
self.persist_events_store._set_push_actions_for_event_and_users_txn, "",
[(event, None)], self.persist_events_store._set_push_actions_for_event_and_users_txn,
[(event, None)], [(event, None)],
[(event, None)],
)
) )
def _rotate(stream): def _rotate(stream):
return self.store.db_pool.runInteraction( return defer.ensureDeferred(
"", self.store._rotate_notifs_before_txn, stream self.store.db_pool.runInteraction(
"", self.store._rotate_notifs_before_txn, stream
)
) )
def _mark_read(stream, depth): def _mark_read(stream, depth):
return self.store.db_pool.runInteraction( return defer.ensureDeferred(
"", self.store.db_pool.runInteraction(
self.store._remove_old_push_actions_before_txn, "",
room_id, self.store._remove_old_push_actions_before_txn,
user_id, room_id,
stream, user_id,
stream,
)
) )
yield _assert_counts(0, 0) yield _assert_counts(0, 0)
@ -163,16 +171,24 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
) )
# start with the base case where there are no events in the table # start with the base case where there are no events in the table
r = yield self.store.find_first_stream_ordering_after_ts(11) r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(11)
)
self.assertEqual(r, 0) self.assertEqual(r, 0)
# now with one event # now with one event
yield add_event(2, 10) yield add_event(2, 10)
r = yield self.store.find_first_stream_ordering_after_ts(9) r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(9)
)
self.assertEqual(r, 2) self.assertEqual(r, 2)
r = yield self.store.find_first_stream_ordering_after_ts(10) r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(10)
)
self.assertEqual(r, 2) self.assertEqual(r, 2)
r = yield self.store.find_first_stream_ordering_after_ts(11) r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(11)
)
self.assertEqual(r, 3) self.assertEqual(r, 3)
# add a bunch of dummy events to the events table # add a bunch of dummy events to the events table
@ -185,25 +201,37 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
): ):
yield add_event(stream_ordering, ts) yield add_event(stream_ordering, ts)
r = yield self.store.find_first_stream_ordering_after_ts(110) r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(110)
)
self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r) self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
# 4 and 5 are both after 120: we want 4 rather than 5 # 4 and 5 are both after 120: we want 4 rather than 5
r = yield self.store.find_first_stream_ordering_after_ts(120) r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(120)
)
self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r) self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
r = yield self.store.find_first_stream_ordering_after_ts(129) r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(129)
)
self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r) self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
# check we can get the last event # check we can get the last event
r = yield self.store.find_first_stream_ordering_after_ts(140) r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(140)
)
self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r) self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
# off the end # off the end
r = yield self.store.find_first_stream_ordering_after_ts(160) r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(160)
)
self.assertEqual(r, 21) self.assertEqual(r, 21)
# check we can find an event at ordering zero # check we can find an event at ordering zero
yield add_event(0, 5) yield add_event(0, 5)
r = yield self.store.find_first_stream_ordering_after_ts(1) r = yield defer.ensureDeferred(
self.store.find_first_stream_ordering_after_ts(1)
)
self.assertEqual(r, 0) self.assertEqual(r, 0)

View file

@ -34,14 +34,16 @@ class DataStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_users_paginate(self): def test_get_users_paginate(self):
yield self.store.register_user(self.user.to_string(), "pass") yield defer.ensureDeferred(
self.store.register_user(self.user.to_string(), "pass")
)
yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
yield defer.ensureDeferred( yield defer.ensureDeferred(
self.store.set_profile_displayname(self.user.localpart, self.displayname) self.store.set_profile_displayname(self.user.localpart, self.displayname)
) )
users, total = yield self.store.get_users_paginate( users, total = yield defer.ensureDeferred(
0, 10, name="bc", guests=False self.store.get_users_paginate(0, 10, name="bc", guests=False)
) )
self.assertEquals(1, total) self.assertEquals(1, total)

View file

@ -37,7 +37,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_register(self): def test_register(self):
yield self.store.register_user(self.user_id, self.pwhash) yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
self.assertEquals( self.assertEquals(
{ {
@ -58,14 +58,16 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_add_tokens(self): def test_add_tokens(self):
yield self.store.register_user(self.user_id, self.pwhash) yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
yield defer.ensureDeferred( yield defer.ensureDeferred(
self.store.add_access_token_to_user( self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
) )
) )
result = yield self.store.get_user_by_access_token(self.tokens[1]) result = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[1])
)
self.assertDictContainsSubset( self.assertDictContainsSubset(
{"name": self.user_id, "device_id": self.device_id}, result {"name": self.user_id, "device_id": self.device_id}, result
@ -76,7 +78,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_user_delete_access_tokens(self): def test_user_delete_access_tokens(self):
# add some tokens # add some tokens
yield self.store.register_user(self.user_id, self.pwhash) yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
yield defer.ensureDeferred( yield defer.ensureDeferred(
self.store.add_access_token_to_user( self.store.add_access_token_to_user(
self.user_id, self.tokens[0], device_id=None, valid_until_ms=None self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
@ -89,22 +91,28 @@ class RegistrationStoreTestCase(unittest.TestCase):
) )
# now delete some # now delete some
yield self.store.user_delete_access_tokens( yield defer.ensureDeferred(
self.user_id, device_id=self.device_id self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
) )
# check they were deleted # check they were deleted
user = yield self.store.get_user_by_access_token(self.tokens[1]) user = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[1])
)
self.assertIsNone(user, "access token was not deleted by device_id") self.assertIsNone(user, "access token was not deleted by device_id")
# check the one not associated with the device was not deleted # check the one not associated with the device was not deleted
user = yield self.store.get_user_by_access_token(self.tokens[0]) user = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[0])
)
self.assertEqual(self.user_id, user["name"]) self.assertEqual(self.user_id, user["name"])
# now delete the rest # now delete the rest
yield self.store.user_delete_access_tokens(self.user_id) yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
user = yield self.store.get_user_by_access_token(self.tokens[0]) user = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[0])
)
self.assertIsNone(user, "access token was not deleted without device_id") self.assertIsNone(user, "access token was not deleted without device_id")
@defer.inlineCallbacks @defer.inlineCallbacks
@ -112,16 +120,20 @@ class RegistrationStoreTestCase(unittest.TestCase):
TEST_USER = "@test:test" TEST_USER = "@test:test"
SUPPORT_USER = "@support:test" SUPPORT_USER = "@support:test"
res = yield self.store.is_support_user(None) res = yield defer.ensureDeferred(self.store.is_support_user(None))
self.assertFalse(res) self.assertFalse(res)
yield self.store.register_user(user_id=TEST_USER, password_hash=None) yield defer.ensureDeferred(
res = yield self.store.is_support_user(TEST_USER) self.store.register_user(user_id=TEST_USER, password_hash=None)
)
res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER))
self.assertFalse(res) self.assertFalse(res)
yield self.store.register_user( yield defer.ensureDeferred(
user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT self.store.register_user(
user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
)
) )
res = yield self.store.is_support_user(SUPPORT_USER) res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER))
self.assertTrue(res) self.assertTrue(res)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -31,10 +31,18 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
# alice and bob are both in !room_id. bobby is not but shares # alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice. # a homeserver with alice.
yield self.store.update_profile_in_user_dir(ALICE, "alice", None) yield defer.ensureDeferred(
yield self.store.update_profile_in_user_dir(BOB, "bob", None) self.store.update_profile_in_user_dir(ALICE, "alice", None)
yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None) )
yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)) yield defer.ensureDeferred(
self.store.update_profile_in_user_dir(BOB, "bob", None)
)
yield defer.ensureDeferred(
self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
)
yield defer.ensureDeferred(
self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_search_user_dir(self): def test_search_user_dir(self):

View file

@ -80,16 +80,16 @@ class StateGroupStore(object):
self._next_group = 1 self._next_group = 1
def get_state_groups_ids(self, room_id, event_ids): async def get_state_groups_ids(self, room_id, event_ids):
groups = {} groups = {}
for event_id in event_ids: for event_id in event_ids:
group = self._event_to_state_group.get(event_id) group = self._event_to_state_group.get(event_id)
if group: if group:
groups[group] = self._group_to_state[group] groups[group] = self._group_to_state[group]
return defer.succeed(groups) return groups
def store_state_group( async def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids self, event_id, room_id, prev_group, delta_ids, current_state_ids
): ):
state_group = self._next_group state_group = self._next_group
@ -97,19 +97,17 @@ class StateGroupStore(object):
self._group_to_state[state_group] = dict(current_state_ids) self._group_to_state[state_group] = dict(current_state_ids)
return defer.succeed(state_group) return state_group
def get_events(self, event_ids, **kwargs): async def get_events(self, event_ids, **kwargs):
return defer.succeed( return {
{ e_id: self._event_id_to_event[e_id]
e_id: self._event_id_to_event[e_id] for e_id in event_ids
for e_id in event_ids if e_id in self._event_id_to_event
if e_id in self._event_id_to_event }
}
)
def get_state_group_delta(self, name): async def get_state_group_delta(self, name):
return defer.succeed((None, None)) return (None, None)
def register_events(self, events): def register_events(self, events):
for e in events: for e in events:
@ -121,8 +119,8 @@ class StateGroupStore(object):
def register_event_id_state_group(self, event_id, state_group): def register_event_id_state_group(self, event_id, state_group):
self._event_to_state_group[event_id] = state_group self._event_to_state_group[event_id] = state_group
def get_room_version_id(self, room_id): async def get_room_version_id(self, room_id):
return defer.succeed(RoomVersions.V1.identifier) return RoomVersions.V1.identifier
class DictObj(dict): class DictObj(dict):
@ -476,12 +474,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
group_name = yield self.store.store_state_group( group_name = yield defer.ensureDeferred(
prev_event_id, self.store.store_state_group(
event.room_id, prev_event_id,
None, event.room_id,
None, None,
{(e.type, e.state_key): e.event_id for e in old_state}, None,
{(e.type, e.state_key): e.event_id for e in old_state},
)
) )
self.store.register_event_id_state_group(prev_event_id, group_name) self.store.register_event_id_state_group(prev_event_id, group_name)
@ -508,12 +508,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
group_name = yield self.store.store_state_group( group_name = yield defer.ensureDeferred(
prev_event_id, self.store.store_state_group(
event.room_id, prev_event_id,
None, event.room_id,
None, None,
{(e.type, e.state_key): e.event_id for e in old_state}, None,
{(e.type, e.state_key): e.event_id for e in old_state},
)
) )
self.store.register_event_id_state_group(prev_event_id, group_name) self.store.register_event_id_state_group(prev_event_id, group_name)
@ -691,21 +693,25 @@ class StateTestCase(unittest.TestCase):
def _get_context( def _get_context(
self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
): ):
sg1 = yield self.store.store_state_group( sg1 = yield defer.ensureDeferred(
prev_event_id_1, self.store.store_state_group(
event.room_id, prev_event_id_1,
None, event.room_id,
None, None,
{(e.type, e.state_key): e.event_id for e in old_state_1}, None,
{(e.type, e.state_key): e.event_id for e in old_state_1},
)
) )
self.store.register_event_id_state_group(prev_event_id_1, sg1) self.store.register_event_id_state_group(prev_event_id_1, sg1)
sg2 = yield self.store.store_state_group( sg2 = yield defer.ensureDeferred(
prev_event_id_2, self.store.store_state_group(
event.room_id, prev_event_id_2,
None, event.room_id,
None, None,
{(e.type, e.state_key): e.event_id for e in old_state_2}, None,
{(e.type, e.state_key): e.event_id for e in old_state_2},
)
) )
self.store.register_event_id_state_group(prev_event_id_2, sg2) self.store.register_event_id_state_group(prev_event_id_2, sg2)

View file

@ -37,7 +37,6 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.hs = yield setup_test_homeserver(self.addCleanup) self.hs = yield setup_test_homeserver(self.addCleanup)
self.event_creation_handler = self.hs.get_event_creation_handler() self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory() self.event_builder_factory = self.hs.get_event_builder_factory()
self.store = self.hs.get_datastore()
self.storage = self.hs.get_storage() self.storage = self.hs.get_storage()
yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@ -99,7 +98,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
events_to_filter.append(evt) events_to_filter.append(evt)
# the erasey user gets erased # the erasey user gets erased
yield self.hs.get_datastore().mark_user_erased("@erased:local_hs") yield defer.ensureDeferred(
self.hs.get_datastore().mark_user_erased("@erased:local_hs")
)
# ... and the filtering happens. # ... and the filtering happens.
filtered = yield defer.ensureDeferred( filtered = yield defer.ensureDeferred(