Fix batching for fetching cross-signing keys

There's no point carefully dividing a list into batches, and then completely
ignoring the batches.
This commit is contained in:
Richard van der Hoff 2020-05-06 00:03:38 +01:00
parent 5b8023dc7f
commit db5f9031b7

View file

@ -25,7 +25,9 @@ from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
class EndToEndKeyWorkerStore(SQLBaseStore): class EndToEndKeyWorkerStore(SQLBaseStore):
@ -391,26 +393,24 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
""" """
result = {} result = {}
batch_size = 100 for user_chunk in batch_iter(user_ids, 100):
chunks = [ clause, params = make_in_list_sql_clause(
user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size) txn.database_engine, "k.user_id", user_chunk
] )
for user_chunk in chunks: sql = (
sql = """ """
SELECT k.user_id, k.keytype, k.keydata, k.stream_id SELECT k.user_id, k.keytype, k.keydata, k.stream_id
FROM e2e_cross_signing_keys k FROM e2e_cross_signing_keys k
INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
FROM e2e_cross_signing_keys FROM e2e_cross_signing_keys
GROUP BY user_id, keytype) s GROUP BY user_id, keytype) s
USING (user_id, stream_id, keytype) USING (user_id, stream_id, keytype)
WHERE k.user_id IN (%s) WHERE
""" % ( """
",".join("?" for u in user_chunk), + clause
) )
query_params = []
query_params.extend(user_chunk)
txn.execute(sql, query_params) txn.execute(sql, params)
rows = self.db.cursor_to_dict(txn) rows = self.db.cursor_to_dict(txn)
for row in rows: for row in rows:
@ -453,15 +453,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
device_id = k device_id = k
devices[(user_id, device_id)] = key_type devices[(user_id, device_id)] = key_type
device_list = list(devices) for batch in batch_iter(devices.keys(), size=100):
# split into batches
batch_size = 100
chunks = [
device_list[i : i + batch_size]
for i in range(0, len(device_list), batch_size)
]
for user_chunk in chunks:
sql = """ sql = """
SELECT target_user_id, target_device_id, key_id, signature SELECT target_user_id, target_device_id, key_id, signature
FROM e2e_cross_signing_signatures FROM e2e_cross_signing_signatures
@ -469,11 +461,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
AND (%s) AND (%s)
""" % ( """ % (
" OR ".join( " OR ".join(
"(target_user_id = ? AND target_device_id = ?)" for d in devices "(target_user_id = ? AND target_device_id = ?)" for _ in batch
) )
) )
query_params = [from_user_id] query_params = [from_user_id]
for item in devices: for item in batch:
# item is a (user_id, device_id) tuple # item is a (user_id, device_id) tuple
query_params.extend(item) query_params.extend(item)