0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-05-23 14:03:45 +02:00

Convert some of the general database methods to async (#8100)

This commit is contained in:
Patrick Cloke 2020-08-17 12:18:01 -04:00 committed by GitHub
parent e04e465b4d
commit 050e20e7ca
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 69 additions and 59 deletions

1
changelog.d/8100.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -332,8 +332,7 @@ class DatabasePool(object):
"""
return self._db_pool.running
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
async def _check_safe_to_upsert(self):
"""
Is it safe to use native UPSERT?
@ -342,7 +341,7 @@ class DatabasePool(object):
If the background updates have not completed, wait 15 sec and check again.
"""
updates = yield self.simple_select_list(
updates = await self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
@ -614,8 +613,7 @@ class DatabasePool(object):
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
@defer.inlineCallbacks
def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
"""Executes an INSERT query on the named table.
Args:
@ -631,7 +629,7 @@ class DatabasePool(object):
`or_ignore` is True
"""
try:
yield self.runInteraction(desc, self.simple_insert_txn, table, values)
await self.runInteraction(desc, self.simple_insert_txn, table, values)
except self.engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db.
@ -684,8 +682,7 @@ class DatabasePool(object):
txn.executemany(sql, vals)
@defer.inlineCallbacks
def simple_upsert(
async def simple_upsert(
self,
table,
keyvalues,
@ -714,14 +711,14 @@ class DatabasePool(object):
inserting
lock (bool): True to lock the table when doing the upsert.
Returns:
Deferred(None or bool): Native upserts always return None. Emulated
None or bool: Native upserts always return None. Emulated
upserts return True if a new entry was created, False if an existing
one was updated.
"""
attempts = 0
while True:
try:
result = yield self.runInteraction(
return await self.runInteraction(
desc,
self.simple_upsert_txn,
table,
@ -730,7 +727,6 @@ class DatabasePool(object):
insertion_values,
lock=lock,
)
return result
except self.engine.module.IntegrityError as e:
attempts += 1
if attempts >= 5:
@ -1121,8 +1117,7 @@ class DatabasePool(object):
return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
def simple_select_many_batch(
async def simple_select_many_batch(
self,
table,
column,
@ -1156,7 +1151,7 @@ class DatabasePool(object):
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
]
for chunk in chunks:
rows = yield self.runInteraction(
rows = await self.runInteraction(
desc,
self.simple_select_many_txn,
table,

View file

@ -169,7 +169,7 @@ class ApplicationServiceTransactionWorkerStore(
service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply.
Returns:
A Deferred which resolves when the state was set successfully.
An Awaitable which resolves when the state was set successfully.
"""
return self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}

View file

@ -847,13 +847,15 @@ class EventsWorkerStore(SQLBaseStore):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
rows = yield self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
rows = yield defer.ensureDeferred(
self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
)
)
return {r["event_id"] for r in rows}

View file

@ -17,9 +17,7 @@
import logging
import re
from typing import Dict, List, Optional
from twisted.internet.defer import Deferred
from typing import Awaitable, Dict, List, Optional
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@ -563,7 +561,7 @@ class RegistrationWorkerStore(SQLBaseStore):
id_server (str)
Returns:
Deferred
Awaitable
"""
# We need to use an upsert, in case they user had already bound the
# threepid
@ -1084,7 +1082,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
) -> Deferred:
) -> Awaitable:
"""Record a mapping from an external user id to a mxid
Args:

View file

@ -767,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
def get_membership_from_event_ids(
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs.
"""
return self.db_pool.simple_select_many_batch(
return await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,

View file

@ -64,7 +64,7 @@ class ProfileTestCase(unittest.TestCase):
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
yield self.store.create_profile(self.frank.localpart)
yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
self.handler = hs.get_profile_handler()
self.hs = hs
@ -157,7 +157,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
yield self.store.create_profile("caroline")
yield defer.ensureDeferred(self.store.create_profile("caroline"))
yield self.store.set_profile_displayname("caroline", "Caroline")
response = yield defer.ensureDeferred(

View file

@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
([], 0)
)
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
None
)

View file

@ -207,7 +207,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_appservices_state_down(self):
service = Mock(id=self.as_list[1]["id"])
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
)
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
@ -219,9 +221,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_appservices_state_multiple_up(self):
service = Mock(id=self.as_list[1]["id"])
yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.UP)
)
yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
)
yield defer.ensureDeferred(
self.store.set_appservice_state(service, ApplicationServiceState.UP)
)
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"

View file

@ -66,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_1col(self):
self.mock_txn.rowcount = 1
yield self.datastore.db_pool.simple_insert(
table="tablename", values={"columname": "Value"}
yield defer.ensureDeferred(
self.datastore.db_pool.simple_insert(
table="tablename", values={"columname": "Value"}
)
)
self.mock_txn.execute.assert_called_with(
@ -78,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_3cols(self):
self.mock_txn.rowcount = 1
yield self.datastore.db_pool.simple_insert(
table="tablename",
# Use OrderedDict() so we can assert on the SQL generated
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
yield defer.ensureDeferred(
self.datastore.db_pool.simple_insert(
table="tablename",
# Use OrderedDict() so we can assert on the SQL generated
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
)
)
self.mock_txn.execute.assert_called_with(

View file

@ -142,20 +142,22 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
return self.store.db_pool.simple_insert(
"events",
{
"stream_ordering": so,
"received_ts": ts,
"event_id": "event%i" % so,
"type": "",
"room_id": "",
"content": "",
"processed": True,
"outlier": False,
"topological_ordering": 0,
"depth": 0,
},
return defer.ensureDeferred(
self.store.db_pool.simple_insert(
"events",
{
"stream_ordering": so,
"received_ts": ts,
"event_id": "event%i" % so,
"type": "",
"room_id": "",
"content": "",
"processed": True,
"outlier": False,
"topological_ordering": 0,
"depth": 0,
},
)
)
# start with the base case where there are no events in the table

View file

@ -35,7 +35,7 @@ class DataStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_users_paginate(self):
yield self.store.register_user(self.user.to_string(), "pass")
yield self.store.create_profile(self.user.localpart)
yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
users, total = yield self.store.get_users_paginate(

View file

@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_displayname(self):
yield self.store.create_profile(self.u_frank.localpart)
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
@ -43,7 +43,7 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_avatar_url(self):
yield self.store.create_profile(self.u_frank.localpart)
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"