Add type hints to synapse/storage/databases/main/transactions.py (#11589)

This commit is contained in:
Dirk Klimpel 2021-12-16 20:59:35 +01:00 committed by GitHub
parent 43f5cc7adc
commit 1847d027e6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 25 deletions

1
changelog.d/11589.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to storage classes.

View file

@ -41,7 +41,6 @@ exclude = (?x)
|synapse/storage/databases/main/search.py |synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py |synapse/storage/databases/main/state.py
|synapse/storage/databases/main/stats.py |synapse/storage/databases/main/stats.py
|synapse/storage/databases/main/transactions.py
|synapse/storage/databases/main/user_directory.py |synapse/storage/databases/main/user_directory.py
|synapse/storage/schema/ |synapse/storage/schema/
@ -216,6 +215,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.state_deltas] [mypy-synapse.storage.databases.main.state_deltas]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.transactions]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.user_erasure_store] [mypy-synapse.storage.databases.main.user_erasure_store]
disallow_untyped_defs = True disallow_untyped_defs = True

View file

@ -13,9 +13,8 @@
# limitations under the License. # limitations under the License.
import logging import logging
from collections import namedtuple
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast
import attr import attr
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
@ -39,16 +38,6 @@ db_binary_type = memoryview
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_TransactionRow = namedtuple(
"_TransactionRow",
("id", "transaction_id", "destination", "ts", "response_code", "response_json"),
)
_UpdateTransactionRow = namedtuple(
"_TransactionRow", ("response_code", "response_json")
)
class DestinationSortOrder(Enum): class DestinationSortOrder(Enum):
"""Enum to define the sorting method used when returning destinations.""" """Enum to define the sorting method used when returning destinations."""
@ -91,7 +80,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
now = self._clock.time_msec() now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000 month_ago = now - 30 * 24 * 60 * 60 * 1000
def _cleanup_transactions_txn(txn): def _cleanup_transactions_txn(txn: LoggingTransaction) -> None:
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
@ -121,7 +110,9 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
origin, origin,
) )
def _get_received_txn_response(self, txn, transaction_id, origin): def _get_received_txn_response(
self, txn: LoggingTransaction, transaction_id: str, origin: str
) -> Optional[Tuple[int, JsonDict]]:
result = self.db_pool.simple_select_one_txn( result = self.db_pool.simple_select_one_txn(
txn, txn,
table="received_transactions", table="received_transactions",
@ -196,7 +187,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
return result return result
def _get_destination_retry_timings( def _get_destination_retry_timings(
self, txn, destination: str self, txn: LoggingTransaction, destination: str
) -> Optional[DestinationRetryTimings]: ) -> Optional[DestinationRetryTimings]:
result = self.db_pool.simple_select_one_txn( result = self.db_pool.simple_select_one_txn(
txn, txn,
@ -231,7 +222,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
""" """
if self.database_engine.can_native_upsert: if self.database_engine.can_native_upsert:
return await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"set_destination_retry_timings", "set_destination_retry_timings",
self._set_destination_retry_timings_native, self._set_destination_retry_timings_native,
destination, destination,
@ -241,7 +232,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
db_autocommit=True, # Safe as its a single upsert db_autocommit=True, # Safe as its a single upsert
) )
else: else:
return await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"set_destination_retry_timings", "set_destination_retry_timings",
self._set_destination_retry_timings_emulated, self._set_destination_retry_timings_emulated,
destination, destination,
@ -251,8 +242,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
) )
def _set_destination_retry_timings_native( def _set_destination_retry_timings_native(
self, txn, destination, failure_ts, retry_last_ts, retry_interval self,
): txn: LoggingTransaction,
destination: str,
failure_ts: Optional[int],
retry_last_ts: int,
retry_interval: int,
) -> None:
assert self.database_engine.can_native_upsert assert self.database_engine.can_native_upsert
# Upsert retry time interval if retry_interval is zero (i.e. we're # Upsert retry time interval if retry_interval is zero (i.e. we're
@ -282,8 +278,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
) )
def _set_destination_retry_timings_emulated( def _set_destination_retry_timings_emulated(
self, txn, destination, failure_ts, retry_last_ts, retry_interval self,
): txn: LoggingTransaction,
destination: str,
failure_ts: Optional[int],
retry_last_ts: int,
retry_interval: int,
) -> None:
self.database_engine.lock_table(txn, "destinations") self.database_engine.lock_table(txn, "destinations")
# We need to be careful here as the data may have changed from under us # We need to be careful here as the data may have changed from under us
@ -393,7 +394,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
last_successful_stream_ordering: the stream_ordering of the most last_successful_stream_ordering: the stream_ordering of the most
recent successfully-sent PDU recent successfully-sent PDU
""" """
return await self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
"destinations", "destinations",
keyvalues={"destination": destination}, keyvalues={"destination": destination},
values={"last_successful_stream_ordering": last_successful_stream_ordering}, values={"last_successful_stream_ordering": last_successful_stream_ordering},
@ -534,7 +535,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
else: else:
order = "ASC" order = "ASC"
args = [] args: List[object] = []
where_statement = "" where_statement = ""
if destination: if destination:
args.extend(["%" + destination.lower() + "%"]) args.extend(["%" + destination.lower() + "%"])
@ -543,7 +544,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
sql_base = f"FROM destinations {where_statement} " sql_base = f"FROM destinations {where_statement} "
sql = f"SELECT COUNT(*) as total_destinations {sql_base}" sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
txn.execute(sql, args) txn.execute(sql, args)
count = txn.fetchone()[0] count = cast(Tuple[int], txn.fetchone())[0]
sql = f""" sql = f"""
SELECT destination, retry_last_ts, retry_interval, failure_ts, SELECT destination, retry_last_ts, retry_interval, failure_ts,