mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 05:43:46 +01:00
Add type hints to synapse/storage/databases/main/transactions.py
(#11589)
This commit is contained in:
parent
43f5cc7adc
commit
1847d027e6
3 changed files with 29 additions and 25 deletions
1
changelog.d/11589.misc
Normal file
1
changelog.d/11589.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add missing type hints to storage classes.
|
4
mypy.ini
4
mypy.ini
|
@ -41,7 +41,6 @@ exclude = (?x)
|
||||||
|synapse/storage/databases/main/search.py
|
|synapse/storage/databases/main/search.py
|
||||||
|synapse/storage/databases/main/state.py
|
|synapse/storage/databases/main/state.py
|
||||||
|synapse/storage/databases/main/stats.py
|
|synapse/storage/databases/main/stats.py
|
||||||
|synapse/storage/databases/main/transactions.py
|
|
||||||
|synapse/storage/databases/main/user_directory.py
|
|synapse/storage/databases/main/user_directory.py
|
||||||
|synapse/storage/schema/
|
|synapse/storage/schema/
|
||||||
|
|
||||||
|
@ -216,6 +215,9 @@ disallow_untyped_defs = True
|
||||||
[mypy-synapse.storage.databases.main.state_deltas]
|
[mypy-synapse.storage.databases.main.state_deltas]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.storage.databases.main.transactions]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.storage.databases.main.user_erasure_store]
|
[mypy-synapse.storage.databases.main.user_erasure_store]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
|
|
@ -13,9 +13,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
@ -39,16 +38,6 @@ db_binary_type = memoryview
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
_TransactionRow = namedtuple(
|
|
||||||
"_TransactionRow",
|
|
||||||
("id", "transaction_id", "destination", "ts", "response_code", "response_json"),
|
|
||||||
)
|
|
||||||
|
|
||||||
_UpdateTransactionRow = namedtuple(
|
|
||||||
"_TransactionRow", ("response_code", "response_json")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DestinationSortOrder(Enum):
|
class DestinationSortOrder(Enum):
|
||||||
"""Enum to define the sorting method used when returning destinations."""
|
"""Enum to define the sorting method used when returning destinations."""
|
||||||
|
|
||||||
|
@ -91,7 +80,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||||
now = self._clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
month_ago = now - 30 * 24 * 60 * 60 * 1000
|
month_ago = now - 30 * 24 * 60 * 60 * 1000
|
||||||
|
|
||||||
def _cleanup_transactions_txn(txn):
|
def _cleanup_transactions_txn(txn: LoggingTransaction) -> None:
|
||||||
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
|
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
|
||||||
|
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
|
@ -121,7 +110,9 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||||
origin,
|
origin,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_received_txn_response(self, txn, transaction_id, origin):
|
def _get_received_txn_response(
|
||||||
|
self, txn: LoggingTransaction, transaction_id: str, origin: str
|
||||||
|
) -> Optional[Tuple[int, JsonDict]]:
|
||||||
result = self.db_pool.simple_select_one_txn(
|
result = self.db_pool.simple_select_one_txn(
|
||||||
txn,
|
txn,
|
||||||
table="received_transactions",
|
table="received_transactions",
|
||||||
|
@ -196,7 +187,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _get_destination_retry_timings(
|
def _get_destination_retry_timings(
|
||||||
self, txn, destination: str
|
self, txn: LoggingTransaction, destination: str
|
||||||
) -> Optional[DestinationRetryTimings]:
|
) -> Optional[DestinationRetryTimings]:
|
||||||
result = self.db_pool.simple_select_one_txn(
|
result = self.db_pool.simple_select_one_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -231,7 +222,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.database_engine.can_native_upsert:
|
if self.database_engine.can_native_upsert:
|
||||||
return await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"set_destination_retry_timings",
|
"set_destination_retry_timings",
|
||||||
self._set_destination_retry_timings_native,
|
self._set_destination_retry_timings_native,
|
||||||
destination,
|
destination,
|
||||||
|
@ -241,7 +232,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||||
db_autocommit=True, # Safe as its a single upsert
|
db_autocommit=True, # Safe as its a single upsert
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"set_destination_retry_timings",
|
"set_destination_retry_timings",
|
||||||
self._set_destination_retry_timings_emulated,
|
self._set_destination_retry_timings_emulated,
|
||||||
destination,
|
destination,
|
||||||
|
@ -251,8 +242,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _set_destination_retry_timings_native(
|
def _set_destination_retry_timings_native(
|
||||||
self, txn, destination, failure_ts, retry_last_ts, retry_interval
|
self,
|
||||||
):
|
txn: LoggingTransaction,
|
||||||
|
destination: str,
|
||||||
|
failure_ts: Optional[int],
|
||||||
|
retry_last_ts: int,
|
||||||
|
retry_interval: int,
|
||||||
|
) -> None:
|
||||||
assert self.database_engine.can_native_upsert
|
assert self.database_engine.can_native_upsert
|
||||||
|
|
||||||
# Upsert retry time interval if retry_interval is zero (i.e. we're
|
# Upsert retry time interval if retry_interval is zero (i.e. we're
|
||||||
|
@ -282,8 +278,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _set_destination_retry_timings_emulated(
|
def _set_destination_retry_timings_emulated(
|
||||||
self, txn, destination, failure_ts, retry_last_ts, retry_interval
|
self,
|
||||||
):
|
txn: LoggingTransaction,
|
||||||
|
destination: str,
|
||||||
|
failure_ts: Optional[int],
|
||||||
|
retry_last_ts: int,
|
||||||
|
retry_interval: int,
|
||||||
|
) -> None:
|
||||||
self.database_engine.lock_table(txn, "destinations")
|
self.database_engine.lock_table(txn, "destinations")
|
||||||
|
|
||||||
# We need to be careful here as the data may have changed from under us
|
# We need to be careful here as the data may have changed from under us
|
||||||
|
@ -393,7 +394,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||||
last_successful_stream_ordering: the stream_ordering of the most
|
last_successful_stream_ordering: the stream_ordering of the most
|
||||||
recent successfully-sent PDU
|
recent successfully-sent PDU
|
||||||
"""
|
"""
|
||||||
return await self.db_pool.simple_upsert(
|
await self.db_pool.simple_upsert(
|
||||||
"destinations",
|
"destinations",
|
||||||
keyvalues={"destination": destination},
|
keyvalues={"destination": destination},
|
||||||
values={"last_successful_stream_ordering": last_successful_stream_ordering},
|
values={"last_successful_stream_ordering": last_successful_stream_ordering},
|
||||||
|
@ -534,7 +535,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||||
else:
|
else:
|
||||||
order = "ASC"
|
order = "ASC"
|
||||||
|
|
||||||
args = []
|
args: List[object] = []
|
||||||
where_statement = ""
|
where_statement = ""
|
||||||
if destination:
|
if destination:
|
||||||
args.extend(["%" + destination.lower() + "%"])
|
args.extend(["%" + destination.lower() + "%"])
|
||||||
|
@ -543,7 +544,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||||
sql_base = f"FROM destinations {where_statement} "
|
sql_base = f"FROM destinations {where_statement} "
|
||||||
sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
|
sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
count = txn.fetchone()[0]
|
count = cast(Tuple[int], txn.fetchone())[0]
|
||||||
|
|
||||||
sql = f"""
|
sql = f"""
|
||||||
SELECT destination, retry_last_ts, retry_interval, failure_ts,
|
SELECT destination, retry_last_ts, retry_interval, failure_ts,
|
||||||
|
|
Loading…
Reference in a new issue