mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-09 03:22:57 +01:00
Sanitize TransactionStore
This commit is contained in:
parent
f6583796fe
commit
278149f533
2 changed files with 104 additions and 87 deletions
|
@ -179,7 +179,7 @@ class FederationHandler(BaseHandler):
|
|||
# it's probably a good idea to mark it as not in retry-state
|
||||
# for sending (although this is a bit of a leap)
|
||||
retry_timings = yield self.store.get_destination_retry_timings(origin)
|
||||
if (retry_timings and retry_timings.retry_last_ts):
|
||||
if retry_timings and retry_timings["retry_last_ts"]:
|
||||
self.store.set_destination_retry_timings(origin, 0, 0)
|
||||
|
||||
room = yield self.store.get_room(event.room_id)
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore, Table, cached
|
||||
from ._base import SQLBaseStore, cached
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
|
@ -84,13 +84,18 @@ class TransactionStore(SQLBaseStore):
|
|||
|
||||
def _set_received_txn_response(self, txn, transaction_id, origin, code,
|
||||
response_json):
|
||||
query = (
|
||||
"UPDATE %s "
|
||||
"SET response_code = ?, response_json = ? "
|
||||
"WHERE transaction_id = ? AND origin = ?"
|
||||
) % ReceivedTransactionsTable.table_name
|
||||
|
||||
txn.execute(query, (code, response_json, transaction_id, origin))
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
table=ReceivedTransactionsTable.table_name,
|
||||
keyvalues={
|
||||
"transaction_id": transaction_id,
|
||||
"origin": origin,
|
||||
},
|
||||
updatevalues={
|
||||
"response_code": code,
|
||||
"response_json": response_json,
|
||||
}
|
||||
)
|
||||
|
||||
def prep_send_transaction(self, transaction_id, destination,
|
||||
origin_server_ts):
|
||||
|
@ -121,38 +126,32 @@ class TransactionStore(SQLBaseStore):
|
|||
# First we find out what the prev_txns should be.
|
||||
# Since we know that we are only sending one transaction at a time,
|
||||
# we can simply take the last one.
|
||||
query = "%s ORDER BY id DESC LIMIT 1" % (
|
||||
SentTransactions.select_statement("destination = ?"),
|
||||
query = (
|
||||
"SELECT * FROM sent_transactions"
|
||||
" WHERE destination = ?"
|
||||
" ORDER BY id DESC LIMIT 1"
|
||||
)
|
||||
|
||||
txn.execute(query, (destination,))
|
||||
results = SentTransactions.decode_results(txn.fetchall())
|
||||
results = self.cursor_to_dict(txn)
|
||||
|
||||
prev_txns = [r.transaction_id for r in results]
|
||||
prev_txns = [r["transaction_id"] for r in results]
|
||||
|
||||
# Actually add the new transaction to the sent_transactions table.
|
||||
|
||||
query = SentTransactions.insert_statement()
|
||||
txn.execute(query, SentTransactions.EntryType(
|
||||
self.get_next_stream_id(),
|
||||
transaction_id=transaction_id,
|
||||
destination=destination,
|
||||
ts=origin_server_ts,
|
||||
response_code=0,
|
||||
response_json=None
|
||||
))
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table=SentTransactions.table_name,
|
||||
values={
|
||||
"transaction_id": self.get_next_stream_id(),
|
||||
"destination": destination,
|
||||
"ts": origin_server_ts,
|
||||
"response_code": 0,
|
||||
"response_json": None,
|
||||
}
|
||||
)
|
||||
|
||||
# Update the tx id -> pdu id mapping
|
||||
|
||||
# values = [
|
||||
# (transaction_id, destination, pdu[0], pdu[1])
|
||||
# for pdu in pdu_list
|
||||
# ]
|
||||
#
|
||||
# logger.debug("Inserting: %s", repr(values))
|
||||
#
|
||||
# query = TransactionsToPduTable.insert_statement()
|
||||
# txn.executemany(query, values)
|
||||
# TODO Update the tx id -> pdu id mapping
|
||||
|
||||
return prev_txns
|
||||
|
||||
|
@ -171,15 +170,20 @@ class TransactionStore(SQLBaseStore):
|
|||
transaction_id, destination, code, response_dict
|
||||
)
|
||||
|
||||
def _delivered_txn(cls, txn, transaction_id, destination,
|
||||
def _delivered_txn(self, txn, transaction_id, destination,
|
||||
code, response_json):
|
||||
query = (
|
||||
"UPDATE %s "
|
||||
"SET response_code = ?, response_json = ? "
|
||||
"WHERE transaction_id = ? AND destination = ?"
|
||||
) % SentTransactions.table_name
|
||||
|
||||
txn.execute(query, (code, response_json, transaction_id, destination))
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
table=SentTransactions.table_name,
|
||||
keyvalues={
|
||||
"transaction_id": transaction_id,
|
||||
"destination": destination,
|
||||
},
|
||||
updatevalues={
|
||||
"response_code": code,
|
||||
"response_json": response_json,
|
||||
}
|
||||
)
|
||||
|
||||
def get_transactions_after(self, transaction_id, destination):
|
||||
"""Get all transactions after a given local transaction_id.
|
||||
|
@ -189,25 +193,26 @@ class TransactionStore(SQLBaseStore):
|
|||
destination (str)
|
||||
|
||||
Returns:
|
||||
list: A list of `ReceivedTransactionsTable.EntryType`
|
||||
list: A list of dicts
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_transactions_after",
|
||||
self._get_transactions_after, transaction_id, destination
|
||||
)
|
||||
|
||||
def _get_transactions_after(cls, txn, transaction_id, destination):
|
||||
where = (
|
||||
"destination = ? AND id > (select id FROM %s WHERE "
|
||||
"transaction_id = ? AND destination = ?)"
|
||||
) % (
|
||||
SentTransactions.table_name
|
||||
def _get_transactions_after(self, txn, transaction_id, destination):
|
||||
query = (
|
||||
"SELECT * FROM sent_transactions"
|
||||
" WHERE destination = ? AND id >"
|
||||
" ("
|
||||
" SELECT id FROM sent_transactions"
|
||||
" WHERE transaction_id = ? AND destination = ?"
|
||||
" )"
|
||||
)
|
||||
query = SentTransactions.select_statement(where)
|
||||
|
||||
txn.execute(query, (destination, transaction_id, destination))
|
||||
|
||||
return ReceivedTransactionsTable.decode_results(txn.fetchall())
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
@cached()
|
||||
def get_destination_retry_timings(self, destination):
|
||||
|
@ -218,19 +223,24 @@ class TransactionStore(SQLBaseStore):
|
|||
|
||||
Returns:
|
||||
None if not retrying
|
||||
Otherwise a DestinationsTable.EntryType for the retry scheme
|
||||
Otherwise a dict for the retry scheme
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_destination_retry_timings",
|
||||
self._get_destination_retry_timings, destination)
|
||||
|
||||
def _get_destination_retry_timings(cls, txn, destination):
|
||||
query = DestinationsTable.select_statement("destination = ?")
|
||||
txn.execute(query, (destination,))
|
||||
result = txn.fetchall()
|
||||
if result:
|
||||
result = DestinationsTable.decode_single_result(result)
|
||||
if result.retry_last_ts > 0:
|
||||
def _get_destination_retry_timings(self, txn, destination):
|
||||
result = self._simple_select_one_txn(
|
||||
txn,
|
||||
table=DestinationsTable.table_name,
|
||||
keyvalues={
|
||||
"destination": destination,
|
||||
},
|
||||
retcols=DestinationsTable.fields,
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if result["retry_last_ts"] > 0:
|
||||
return result
|
||||
else:
|
||||
return None
|
||||
|
@ -249,11 +259,11 @@ class TransactionStore(SQLBaseStore):
|
|||
# As this is the new value, we might as well prefill the cache
|
||||
self.get_destination_retry_timings.prefill(
|
||||
destination,
|
||||
DestinationsTable.EntryType(
|
||||
destination,
|
||||
retry_last_ts,
|
||||
retry_interval
|
||||
)
|
||||
{
|
||||
"destination": destination,
|
||||
"retry_last_ts": retry_last_ts,
|
||||
"retry_interval": retry_interval
|
||||
},
|
||||
)
|
||||
|
||||
# XXX: we could chose to not bother persisting this if our cache thinks
|
||||
|
@ -270,18 +280,27 @@ class TransactionStore(SQLBaseStore):
|
|||
retry_last_ts, retry_interval):
|
||||
|
||||
query = (
|
||||
"REPLACE INTO %s "
|
||||
"INSERT INTO destinations"
|
||||
" (destination, retry_last_ts, retry_interval)"
|
||||
" VALUES (?, ?, ?)"
|
||||
) % DestinationsTable.table_name
|
||||
" ON DUPLICATE KEY UPDATE"
|
||||
" retry_last_ts=?, retry_interval=?"
|
||||
)
|
||||
|
||||
txn.execute(query, (destination, retry_last_ts, retry_interval))
|
||||
txn.execute(
|
||||
query,
|
||||
(
|
||||
destination,
|
||||
retry_last_ts, retry_interval,
|
||||
retry_last_ts, retry_interval,
|
||||
)
|
||||
)
|
||||
|
||||
def get_destinations_needing_retry(self):
|
||||
"""Get all destinations which are due a retry for sending a transaction.
|
||||
|
||||
Returns:
|
||||
list: A list of `DestinationsTable.EntryType`
|
||||
list: A list of dicts
|
||||
"""
|
||||
|
||||
return self.runInteraction(
|
||||
|
@ -289,14 +308,17 @@ class TransactionStore(SQLBaseStore):
|
|||
self._get_destinations_needing_retry
|
||||
)
|
||||
|
||||
def _get_destinations_needing_retry(cls, txn):
|
||||
where = "retry_last_ts > 0 and retry_next_ts < now()"
|
||||
query = DestinationsTable.select_statement(where)
|
||||
txn.execute(query)
|
||||
return DestinationsTable.decode_results(txn.fetchall())
|
||||
def _get_destinations_needing_retry(self, txn):
|
||||
query = (
|
||||
"SELECT * FROM destinations"
|
||||
" WHERE retry_last_ts > 0 and retry_next_ts < ?"
|
||||
)
|
||||
|
||||
txn.execute(query, (self._clock.time_msec(),))
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
|
||||
class ReceivedTransactionsTable(Table):
|
||||
class ReceivedTransactionsTable(object):
|
||||
table_name = "received_transactions"
|
||||
|
||||
fields = [
|
||||
|
@ -308,10 +330,8 @@ class ReceivedTransactionsTable(Table):
|
|||
"has_been_referenced",
|
||||
]
|
||||
|
||||
EntryType = namedtuple("ReceivedTransactionsEntry", fields)
|
||||
|
||||
|
||||
class SentTransactions(Table):
|
||||
class SentTransactions(object):
|
||||
table_name = "sent_transactions"
|
||||
|
||||
fields = [
|
||||
|
@ -326,7 +346,7 @@ class SentTransactions(Table):
|
|||
EntryType = namedtuple("SentTransactionsEntry", fields)
|
||||
|
||||
|
||||
class TransactionsToPduTable(Table):
|
||||
class TransactionsToPduTable(object):
|
||||
table_name = "transaction_id_to_pdu"
|
||||
|
||||
fields = [
|
||||
|
@ -336,10 +356,8 @@ class TransactionsToPduTable(Table):
|
|||
"pdu_origin",
|
||||
]
|
||||
|
||||
EntryType = namedtuple("TransactionsToPduEntry", fields)
|
||||
|
||||
|
||||
class DestinationsTable(Table):
|
||||
class DestinationsTable(object):
|
||||
table_name = "destinations"
|
||||
|
||||
fields = [
|
||||
|
@ -348,4 +366,3 @@ class DestinationsTable(Table):
|
|||
"retry_interval",
|
||||
]
|
||||
|
||||
EntryType = namedtuple("DestinationsEntry", fields)
|
||||
|
|
Loading…
Reference in a new issue