0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-21 13:12:01 +01:00

Merge pull request #6156 from matrix-org/erikj/postgres_any

Use Postgres ANY for selecting many values.
This commit is contained in:
Erik Johnston 2019-10-10 16:41:36 +01:00 committed by GitHub
commit 83d86106a8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 189 additions and 108 deletions

1
changelog.d/6156.misc Normal file
View file

@ -0,0 +1 @@
Use Postgres ANY for selecting many values.

View file

@ -20,6 +20,7 @@ import random
import sys import sys
import threading import threading
import time import time
from typing import Iterable, Tuple
from six import PY2, iteritems, iterkeys, itervalues from six import PY2, iteritems, iterkeys, itervalues
from six.moves import builtins, intern, range from six.moves import builtins, intern, range
@ -1163,19 +1164,18 @@ class SQLBaseStore(object):
if not iterable: if not iterable:
return [] return []
sql = "SELECT %s FROM %s" % (", ".join(retcols), table) clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
clauses = [clause]
clauses = []
values = []
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
values.extend(iterable)
for key, value in iteritems(keyvalues): for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,)) clauses.append("%s = ?" % (key,))
values.append(value) values.append(value)
if clauses: sql = "SELECT %s FROM %s WHERE %s" % (
sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) ", ".join(retcols),
table,
" AND ".join(clauses),
)
txn.execute(sql, values) txn.execute(sql, values)
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
@ -1324,10 +1324,8 @@ class SQLBaseStore(object):
sql = "DELETE FROM %s" % table sql = "DELETE FROM %s" % table
clauses = [] clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
values = [] clauses = [clause]
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
values.extend(iterable)
for key, value in iteritems(keyvalues): for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,)) clauses.append("%s = ?" % (key,))
@ -1694,3 +1692,30 @@ def db_to_json(db_content):
except Exception: except Exception:
logging.warning("Tried to decode '%r' as JSON and failed", db_content) logging.warning("Tried to decode '%r' as JSON and failed", db_content)
raise raise
def make_in_list_sql_clause(
database_engine, column: str, iterable: Iterable
) -> Tuple[str, Iterable]:
"""Returns an SQL clause that checks the given column is in the iterable.
On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
it expands to `column = ANY(?)`. While both DBs support the `IN` form,
using the `ANY` form on postgres means that it views queries with
different length iterables as the same, helping the query stats.
Args:
database_engine
column: Name of the column
iterable: The values to check the column against.
Returns:
A tuple of SQL query and the args
"""
if database_engine.supports_using_any_list:
# This should hopefully be faster, but also makes postgres query
# stats easier to understand.
return "%s = ANY(?)" % (column,), [list(iterable)]
else:
return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)

View file

@ -20,7 +20,7 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -378,15 +378,15 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
else: else:
if not devices: if not devices:
continue continue
sql = (
"SELECT device_id FROM devices" clause, args = make_in_list_sql_clause(
" WHERE user_id = ? AND device_id IN (" txn.database_engine, "device_id", devices
+ ",".join("?" * len(devices))
+ ")"
) )
sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
# TODO: Maybe this needs to be done in batches if there are # TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user. # too many local devices for a given user.
txn.execute(sql, [user_id] + devices) txn.execute(sql, [user_id] + list(args))
for row in txn: for row in txn:
# Only insert into the local inbox if the device exists on # Only insert into the local inbox if the device exists on
# this server # this server

View file

@ -28,7 +28,12 @@ from synapse.logging.opentracing import (
whitelisted_homeserver, whitelisted_homeserver,
) )
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import Cache, SQLBaseStore, db_to_json from synapse.storage._base import (
Cache,
SQLBaseStore,
db_to_json,
make_in_list_sql_clause,
)
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util import batch_iter from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
@ -448,11 +453,14 @@ class DeviceWorkerStore(SQLBaseStore):
sql = """ sql = """
SELECT DISTINCT user_id FROM device_lists_stream SELECT DISTINCT user_id FROM device_lists_stream
WHERE stream_id > ? WHERE stream_id > ?
AND user_id IN (%s) AND
""" """
for chunk in batch_iter(to_check, 100): for chunk in batch_iter(to_check, 100):
txn.execute(sql % (",".join("?" for _ in chunk),), (from_key,) + chunk) clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
txn.execute(sql + clause, (from_key,) + tuple(args))
changes.update(user_id for user_id, in txn) changes.update(user_id for user_id, in txn)
return changes return changes

View file

@ -86,6 +86,12 @@ class PostgresEngine(object):
""" """
return True return True
@property
def supports_using_any_list(self):
"""Do we support using `a = ANY(?)` and passing a list
"""
return True
def is_deadlock(self, error): def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError): if isinstance(error, self.module.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html # https://www.postgresql.org/docs/current/static/errcodes-appendix.html

View file

@ -46,6 +46,12 @@ class Sqlite3Engine(object):
""" """
return self.module.sqlite_version_info >= (3, 15, 0) return self.module.sqlite_version_info >= (3, 15, 0)
@property
def supports_using_any_list(self):
"""Do we support using `a = ANY(?)` and passing a list
"""
return False
def check_database(self, txn): def check_database(self, txn):
pass pass

View file

@ -25,7 +25,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.storage.signatures import SignatureWorkerStore from synapse.storage.signatures import SignatureWorkerStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -68,7 +68,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else: else:
results = set() results = set()
base_sql = "SELECT auth_id FROM event_auth WHERE event_id IN (%s)" base_sql = "SELECT auth_id FROM event_auth WHERE "
front = set(event_ids) front = set(event_ids)
while front: while front:
@ -76,7 +76,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
front_list = list(front) front_list = list(front)
chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)] chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)]
for chunk in chunks: for chunk in chunks:
txn.execute(base_sql % (",".join(["?"] * len(chunk)),), chunk) clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", chunk
)
txn.execute(base_sql + clause, list(args))
new_front.update([r[0] for r in txn]) new_front.update([r[0] for r in txn])
new_front -= results new_front -= results

View file

@ -39,6 +39,7 @@ from synapse.logging.utils import log_function
from synapse.metrics import BucketCollector from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateResolutionStore from synapse.state import StateResolutionStore
from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.event_federation import EventFederationStore from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
@ -641,14 +642,16 @@ class EventsStore(
LEFT JOIN rejections USING (event_id) LEFT JOIN rejections USING (event_id)
LEFT JOIN event_json USING (event_id) LEFT JOIN event_json USING (event_id)
WHERE WHERE
prev_event_id IN (%s) NOT events.outlier
AND NOT events.outlier
AND rejections.event_id IS NULL AND rejections.event_id IS NULL
""" % ( AND
",".join("?" for _ in batch), """
clause, args = make_in_list_sql_clause(
self.database_engine, "prev_event_id", batch
) )
txn.execute(sql, batch) txn.execute(sql + clause, args)
results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100): for chunk in batch_iter(event_ids, 100):
@ -695,13 +698,15 @@ class EventsStore(
LEFT JOIN rejections USING (event_id) LEFT JOIN rejections USING (event_id)
LEFT JOIN event_json USING (event_id) LEFT JOIN event_json USING (event_id)
WHERE WHERE
event_id IN (%s) NOT events.outlier
AND NOT events.outlier AND
""" % ( """
",".join("?" for _ in to_recursively_check),
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", to_recursively_check
) )
txn.execute(sql, to_recursively_check) txn.execute(sql + clause, args)
to_recursively_check = [] to_recursively_check = []
for event_id, prev_event_id, metadata, rejected in txn: for event_id, prev_event_id, metadata, rejected in txn:
@ -1543,10 +1548,14 @@ class EventsStore(
" FROM events as e" " FROM events as e"
" LEFT JOIN rejections as rej USING (event_id)" " LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts" " LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)" " WHERE "
) % (",".join(["?"] * len(ev_map)),) )
txn.execute(sql, list(ev_map)) clause, args = make_in_list_sql_clause(
self.database_engine, "e.event_id", list(ev_map)
)
txn.execute(sql + clause, args)
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
for row in rows: for row in rows:
event = ev_map[row["event_id"]] event = ev_map[row["event_id"]]
@ -2249,11 +2258,12 @@ class EventsStore(
sql = """ sql = """
SELECT DISTINCT state_group FROM event_to_state_groups SELECT DISTINCT state_group FROM event_to_state_groups
LEFT JOIN events_to_purge AS ep USING (event_id) LEFT JOIN events_to_purge AS ep USING (event_id)
WHERE state_group IN (%s) AND ep.event_id IS NULL WHERE ep.event_id IS NULL AND
""" % ( """
",".join("?" for _ in current_search), clause, args = make_in_list_sql_clause(
txn.database_engine, "state_group", current_search
) )
txn.execute(sql, list(current_search)) txn.execute(sql + clause, list(args))
referenced = set(sg for sg, in txn) referenced = set(sg for sg, in txn)
referenced_groups |= referenced referenced_groups |= referenced

View file

@ -21,6 +21,7 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -325,12 +326,13 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
INNER JOIN event_json USING (event_id) INNER JOIN event_json USING (event_id)
LEFT JOIN rejections USING (event_id) LEFT JOIN rejections USING (event_id)
WHERE WHERE
prev_event_id IN (%s) NOT events.outlier
AND NOT events.outlier AND
""" % ( """
",".join("?" for _ in to_check), clause, args = make_in_list_sql_clause(
self.database_engine, "prev_event_id", to_check
) )
txn.execute(sql, to_check) txn.execute(sql + clause, list(args))
for prev_event_id, event_id, metadata, rejected in txn: for prev_event_id, event_id, metadata, rejected in txn:
if event_id in graph: if event_id in graph:

View file

@ -31,12 +31,11 @@ from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util import batch_iter from synapse.util import batch_iter
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from ._base import SQLBaseStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -623,10 +622,14 @@ class EventsWorkerStore(SQLBaseStore):
" rej.reason " " rej.reason "
" FROM event_json as e" " FROM event_json as e"
" LEFT JOIN rejections as rej USING (event_id)" " LEFT JOIN rejections as rej USING (event_id)"
" WHERE e.event_id IN (%s)" " WHERE "
) % (",".join(["?"] * len(evs)),) )
txn.execute(sql, evs) clause, args = make_in_list_sql_clause(
txn.database_engine, "e.event_id", evs
)
txn.execute(sql + clause, args)
for row in txn: for row in txn:
event_id = row[0] event_id = row[0]
@ -640,11 +643,11 @@ class EventsWorkerStore(SQLBaseStore):
} }
# check for redactions # check for redactions
redactions_sql = ( redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "
"SELECT event_id, redacts FROM redactions WHERE redacts IN (%s)"
) % (",".join(["?"] * len(evs)),)
txn.execute(redactions_sql, evs) clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs)
txn.execute(redactions_sql + clause, args)
for (redacter, redacted) in txn: for (redacter, redacted) in txn:
d = event_dict.get(redacted) d = event_dict.get(redacted)
@ -753,10 +756,11 @@ class EventsWorkerStore(SQLBaseStore):
results = set() results = set()
def have_seen_events_txn(txn, chunk): def have_seen_events_txn(txn, chunk):
sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % ( sql = "SELECT event_id FROM events as e WHERE "
",".join("?" * len(chunk)), clause, args = make_in_list_sql_clause(
txn.database_engine, "e.event_id", chunk
) )
txn.execute(sql, chunk) txn.execute(sql + clause, args)
for (event_id,) in txn: for (event_id,) in txn:
results.add(event_id) results.add(event_id)

View file

@ -18,11 +18,10 @@ from collections import namedtuple
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.util import batch_iter from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from ._base import SQLBaseStore
class UserPresenceState( class UserPresenceState(
namedtuple( namedtuple(
@ -119,14 +118,13 @@ class PresenceStore(SQLBaseStore):
) )
# Delete old rows to stop database from getting really big # Delete old rows to stop database from getting really big
sql = ( sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
"DELETE FROM presence_stream WHERE" " stream_id < ?" " AND user_id IN (%s)"
)
for states in batch_iter(presence_states, 50): for states in batch_iter(presence_states, 50):
args = [stream_id] clause, args = make_in_list_sql_clause(
args.extend(s.user_id for s in states) self.database_engine, "user_id", [s.user_id for s in states]
txn.execute(sql % (",".join("?" for _ in states),), args) )
txn.execute(sql + clause, [stream_id] + list(args))
def get_all_presence_updates(self, last_id, current_id): def get_all_presence_updates(self, last_id, current_id):
if last_id == current_id: if last_id == current_id:

View file

@ -21,12 +21,11 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import SQLBaseStore
from .util.id_generators import StreamIdGenerator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -217,24 +216,26 @@ class ReceiptsWorkerStore(SQLBaseStore):
def f(txn): def f(txn):
if from_key: if from_key:
sql = ( sql = """
"SELECT * FROM receipts_linearized WHERE" SELECT * FROM receipts_linearized WHERE
" room_id IN (%s) AND stream_id > ? AND stream_id <= ?" stream_id > ? AND stream_id <= ? AND
) % (",".join(["?"] * len(room_ids))) """
args = list(room_ids) clause, args = make_in_list_sql_clause(
args.extend([from_key, to_key]) self.database_engine, "room_id", room_ids
)
txn.execute(sql, args) txn.execute(sql + clause, [from_key, to_key] + list(args))
else: else:
sql = ( sql = """
"SELECT * FROM receipts_linearized WHERE" SELECT * FROM receipts_linearized WHERE
" room_id IN (%s) AND stream_id <= ?" stream_id <= ? AND
) % (",".join(["?"] * len(room_ids))) """
args = list(room_ids) clause, args = make_in_list_sql_clause(
args.append(to_key) self.database_engine, "room_id", room_ids
)
txn.execute(sql, args) txn.execute(sql + clause, [to_key] + list(args))
return self.cursor_to_dict(txn) return self.cursor_to_dict(txn)
@ -433,13 +434,19 @@ class ReceiptsStore(ReceiptsWorkerStore):
# we need to points in graph -> linearized form. # we need to points in graph -> linearized form.
# TODO: Make this better. # TODO: Make this better.
def graph_to_linear(txn): def graph_to_linear(txn):
query = ( clause, args = make_in_list_sql_clause(
"SELECT event_id WHERE room_id = ? AND stream_ordering IN (" self.database_engine, "event_id", event_ids
" SELECT max(stream_ordering) WHERE event_id IN (%s)" )
")"
) % (",".join(["?"] * len(event_ids)))
txn.execute(query, [room_id] + event_ids) sql = """
SELECT event_id WHERE room_id = ? AND stream_ordering IN (
SELECT max(stream_ordering) WHERE %s
)
""" % (
clause,
)
txn.execute(sql, [room_id] + list(args))
rows = txn.fetchall() rows = txn.fetchall()
if rows: if rows:
return rows[0][0] return rows[0][0]

View file

@ -26,7 +26,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import Sqlite3Engine from synapse.storage.engines import Sqlite3Engine
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
@ -372,6 +372,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
results = [] results = []
if membership_list: if membership_list:
if self._current_state_events_membership_up_to_date: if self._current_state_events_membership_up_to_date:
clause, args = make_in_list_sql_clause(
self.database_engine, "c.membership", membership_list
)
sql = """ sql = """
SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
FROM current_state_events AS c FROM current_state_events AS c
@ -379,11 +382,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
WHERE WHERE
c.type = 'm.room.member' c.type = 'm.room.member'
AND state_key = ? AND state_key = ?
AND c.membership IN (%s) AND %s
""" % ( """ % (
",".join("?" * len(membership_list)) clause,
) )
else: else:
clause, args = make_in_list_sql_clause(
self.database_engine, "m.membership", membership_list
)
sql = """ sql = """
SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
FROM current_state_events AS c FROM current_state_events AS c
@ -392,12 +398,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
WHERE WHERE
c.type = 'm.room.member' c.type = 'm.room.member'
AND state_key = ? AND state_key = ?
AND m.membership IN (%s) AND %s
""" % ( """ % (
",".join("?" * len(membership_list)) clause,
) )
txn.execute(sql, (user_id, *membership_list)) txn.execute(sql, (user_id, *args))
results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
if do_invite: if do_invite:

View file

@ -24,6 +24,7 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from .background_updates import BackgroundUpdateStore from .background_updates import BackgroundUpdateStore
@ -385,8 +386,10 @@ class SearchStore(SearchBackgroundUpdateStore):
# Make sure we don't explode because the person is in too many rooms. # Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless. # We filter the results below regardless.
if len(room_ids) < 500: if len(room_ids) < 500:
clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)) clause, args = make_in_list_sql_clause(
args.extend(room_ids) self.database_engine, "room_id", room_ids
)
clauses = [clause]
local_clauses = [] local_clauses = []
for key in keys: for key in keys:
@ -492,8 +495,10 @@ class SearchStore(SearchBackgroundUpdateStore):
# Make sure we don't explode because the person is in too many rooms. # Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless. # We filter the results below regardless.
if len(room_ids) < 500: if len(room_ids) < 500:
clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)) clause, args = make_in_list_sql_clause(
args.extend(room_ids) self.database_engine, "room_id", room_ids
)
clauses = [clause]
local_clauses = [] local_clauses = []
for key in keys: for key in keys:

View file

@ -56,15 +56,15 @@ class UserErasureWorkerStore(SQLBaseStore):
# iterate it multiple times, and (b) avoiding duplicates. # iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids)) user_ids = tuple(set(user_ids))
def _get_erased_users(txn): rows = yield self._simple_select_many_batch(
txn.execute( table="erased_users",
"SELECT user_id FROM erased_users WHERE user_id IN (%s)" column="user_id",
% (",".join("?" * len(user_ids))), iterable=user_ids,
user_ids, retcols=("user_id",),
) desc="are_users_erased",
return set(r[0] for r in txn) )
erased_users = set(row["user_id"] for row in rows)
erased_users = yield self.runInteraction("are_users_erased", _get_erased_users)
res = dict((u, u in erased_users) for u in user_ids) res = dict((u, u in erased_users) for u in user_ids)
return res return res