Convert additional database methods to async (select list, search, insert_many, delete_*) (#8168)

This commit is contained in:
Patrick Cloke 2020-08-27 07:41:01 -04:00 committed by GitHub
parent 4a739c73b4
commit 30426c7063
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 67 additions and 84 deletions

1
changelog.d/8168.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -414,13 +414,14 @@ class BackgroundUpdater(object):
self.register_background_update_handler(update_name, updater) self.register_background_update_handler(update_name, updater)
def _end_background_update(self, update_name): async def _end_background_update(self, update_name: str) -> None:
"""Removes a completed background update task from the queue. """Removes a completed background update task from the queue.
Args: Args:
update_name(str): The name of the completed task to remove update_name:: The name of the completed task to remove
Returns: Returns:
A deferred that completes once the task is removed. None, completes once the task is removed.
""" """
if update_name != self._current_background_update: if update_name != self._current_background_update:
raise Exception( raise Exception(
@ -428,7 +429,7 @@ class BackgroundUpdater(object):
% update_name % update_name
) )
self._current_background_update = None self._current_background_update = None
return self.db_pool.simple_delete_one( await self.db_pool.simple_delete_one(
"background_updates", keyvalues={"update_name": update_name} "background_updates", keyvalues={"update_name": update_name}
) )

View file

@ -605,7 +605,13 @@ class DatabasePool(object):
results = [dict(zip(col_headers, row)) for row in cursor] results = [dict(zip(col_headers, row)) for row in cursor]
return results return results
def execute(self, desc: str, decoder: Callable, query: str, *args: Any): async def execute(
self,
desc: str,
decoder: Optional[Callable[[Cursor], R]],
query: str,
*args: Any
) -> R:
"""Runs a single query for a result set. """Runs a single query for a result set.
Args: Args:
@ -614,7 +620,7 @@ class DatabasePool(object):
query - The query string to execute query - The query string to execute
*args - Query args. *args - Query args.
Returns: Returns:
Deferred which results to the result of decoder(results) The result of decoder(results)
""" """
def interaction(txn): def interaction(txn):
@ -624,7 +630,7 @@ class DatabasePool(object):
else: else:
return txn.fetchall() return txn.fetchall()
return self.runInteraction(desc, interaction) return await self.runInteraction(desc, interaction)
# "Simple" SQL API methods that operate on a single table with no JOINs, # "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns. # no complex WHERE clauses, just a dict of values for columns.
@ -673,15 +679,30 @@ class DatabasePool(object):
txn.execute(sql, vals) txn.execute(sql, vals)
def simple_insert_many( async def simple_insert_many(
self, table: str, values: List[Dict[str, Any]], desc: str self, table: str, values: List[Dict[str, Any]], desc: str
) -> defer.Deferred: ) -> None:
return self.runInteraction(desc, self.simple_insert_many_txn, table, values) """Executes an INSERT query on the named table.
Args:
table: string giving the table name
values: dict of new column names and values for them
desc: string giving a description of the transaction
"""
await self.runInteraction(desc, self.simple_insert_many_txn, table, values)
@staticmethod @staticmethod
def simple_insert_many_txn( def simple_insert_many_txn(
txn: LoggingTransaction, table: str, values: List[Dict[str, Any]] txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
) -> None: ) -> None:
"""Executes an INSERT query on the named table.
Args:
txn: The transaction to use.
table: string giving the table name
values: dict of new column names and values for them
desc: string giving a description of the transaction
"""
if not values: if not values:
return return
@ -1397,9 +1418,9 @@ class DatabasePool(object):
return dict(zip(retcols, row)) return dict(zip(retcols, row))
def simple_delete_one( async def simple_delete_one(
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one" self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
) -> defer.Deferred: ) -> None:
"""Executes a DELETE query on the named table, expecting to delete a """Executes a DELETE query on the named table, expecting to delete a
single row. single row.
@ -1407,7 +1428,7 @@ class DatabasePool(object):
table: string giving the table name table: string giving the table name
keyvalues: dict of column names and values to select the row with keyvalues: dict of column names and values to select the row with
""" """
return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
@staticmethod @staticmethod
def simple_delete_one_txn( def simple_delete_one_txn(
@ -1446,15 +1467,15 @@ class DatabasePool(object):
txn.execute(sql, list(keyvalues.values())) txn.execute(sql, list(keyvalues.values()))
return txn.rowcount return txn.rowcount
def simple_delete_many( async def simple_delete_many(
self, self,
table: str, table: str,
column: str, column: str,
iterable: Iterable[Any], iterable: Iterable[Any],
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
desc: str, desc: str,
) -> defer.Deferred: ) -> int:
return self.runInteraction( return await self.runInteraction(
desc, self.simple_delete_many_txn, table, column, iterable, keyvalues desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
) )
@ -1537,52 +1558,6 @@ class DatabasePool(object):
return cache, min_val return cache, min_val
def simple_select_list_paginate(
self,
table: str,
orderby: str,
start: int,
limit: int,
retcols: Iterable[str],
filters: Optional[Dict[str, Any]] = None,
keyvalues: Optional[Dict[str, Any]] = None,
order_direction: str = "ASC",
desc: str = "simple_select_list_paginate",
) -> defer.Deferred:
"""
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts.
Args:
table: the table name
orderby: Column to order the results by.
start: Index to begin the query at.
limit: Number of results to return.
retcols: the names of the columns to return
filters:
column names and values to filter the rows with, or None to not
apply a WHERE ? LIKE ? clause.
keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
order_direction: Whether the results should be ordered "ASC" or "DESC".
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.runInteraction(
desc,
self.simple_select_list_paginate_txn,
table,
orderby,
start,
limit,
retcols,
filters=filters,
keyvalues=keyvalues,
order_direction=order_direction,
)
@classmethod @classmethod
def simple_select_list_paginate_txn( def simple_select_list_paginate_txn(
cls, cls,
@ -1647,14 +1622,14 @@ class DatabasePool(object):
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
def simple_search_list( async def simple_search_list(
self, self,
table: str, table: str,
term: Optional[str], term: Optional[str],
col: str, col: str,
retcols: Iterable[str], retcols: Iterable[str],
desc="simple_search_list", desc="simple_search_list",
): ) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of dicts.
@ -1665,10 +1640,10 @@ class DatabasePool(object):
retcols: the names of the columns to return retcols: the names of the columns to return
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] or None A list of dictionaries or None.
""" """
return self.runInteraction( return await self.runInteraction(
desc, self.simple_search_list_txn, table, term, col, retcols desc, self.simple_search_list_txn, table, term, col, retcols
) )

View file

@ -18,7 +18,7 @@
import calendar import calendar
import logging import logging
import time import time
from typing import Any, Dict, List from typing import Any, Dict, List, Optional
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
@ -559,17 +559,17 @@ class DataStore(
"get_users_paginate_txn", get_users_paginate_txn "get_users_paginate_txn", get_users_paginate_txn
) )
def search_users(self, term): async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
"""Function to search users list for one or more users with """Function to search users list for one or more users with
the matched term. the matched term.
Args: Args:
term (str): search term term: search term
col (str): column to query term should be matched to
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] A list of dictionaries or None.
""" """
return self.db_pool.simple_search_list( return await self.db_pool.simple_search_list(
table="users", table="users",
term=term, term=term,
col="name", col="name",

View file

@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
@ -27,6 +27,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.handlers.e2e_keys import SignatureListItem
class EndToEndKeyWorkerStore(SQLBaseStore): class EndToEndKeyWorkerStore(SQLBaseStore):
@trace @trace
@ -730,14 +733,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
stream_id, stream_id,
) )
def store_e2e_cross_signing_signatures(self, user_id, signatures): async def store_e2e_cross_signing_signatures(
self, user_id: str, signatures: "Iterable[SignatureListItem]"
) -> None:
"""Stores cross-signing signatures. """Stores cross-signing signatures.
Args: Args:
user_id (str): the user who made the signatures user_id: the user who made the signatures
signatures (iterable[SignatureListItem]): signatures to add signatures: signatures to add
""" """
return self.db_pool.simple_insert_many( await self.db_pool.simple_insert_many(
"e2e_cross_signing_signatures", "e2e_cross_signing_signatures",
[ [
{ {

View file

@ -314,14 +314,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_remote_media_thumbnail", desc="store_remote_media_thumbnail",
) )
def get_remote_media_before(self, before_ts): async def get_remote_media_before(self, before_ts):
sql = ( sql = (
"SELECT media_origin, media_id, filesystem_id" "SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache" " FROM remote_media_cache"
" WHERE last_access_ts < ?" " WHERE last_access_ts < ?"
) )
return self.db_pool.execute( return await self.db_pool.execute(
"get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
) )

View file

@ -67,13 +67,12 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# second step: complete the update # second step: complete the update
# we should now get run with a much bigger number of items to update # we should now get run with a much bigger number of items to update
@defer.inlineCallbacks async def update(progress, count):
def update(progress, count):
self.assertEqual(progress, {"my_key": 2}) self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual( self.assertAlmostEqual(
count, target_background_update_duration_ms / duration_ms, places=0, count, target_background_update_duration_ms / duration_ms, places=0,
) )
yield self.updates._end_background_update("test_update") await self.updates._end_background_update("test_update")
return count return count
self.update_handler.side_effect = update self.update_handler.side_effect = update

View file

@ -197,8 +197,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_delete_one(self): def test_delete_one(self):
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
yield self.datastore.db_pool.simple_delete_one( yield defer.ensureDeferred(
table="tablename", keyvalues={"keycol": "Go away"} self.datastore.db_pool.simple_delete_one(
table="tablename", keyvalues={"keycol": "Go away"}
)
) )
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(