diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index abe16334e..a94cbc27d 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -20,6 +20,7 @@ import random import sys import threading import time +from typing import Iterable, List, Tuple from six import PY2, iteritems, iterkeys, itervalues from six.moves import builtins, intern, range @@ -1162,19 +1163,20 @@ class SQLBaseStore(object): if not iterable: return [] - sql = "SELECT %s FROM %s" % (", ".join(retcols), table) - clauses = [] values = [] - clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable))) - values.extend(iterable) + + add_in_list_sql_clause(txn.database_engine, column, iterable, clauses, values) for key, value in iteritems(keyvalues): clauses.append("%s = ?" % (key,)) values.append(value) - if clauses: - sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) + sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join(clauses), + ) txn.execute(sql, values) return cls.cursor_to_dict(txn) @@ -1325,8 +1327,8 @@ class SQLBaseStore(object): clauses = [] values = [] - clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable))) - values.extend(iterable) + + add_in_list_sql_clause(txn.database_engine, column, iterable, clauses, values) for key, value in iteritems(keyvalues): clauses.append("%s = ?" % (key,)) @@ -1693,3 +1695,49 @@ def db_to_json(db_content): except Exception: logging.warning("Tried to decode '%r' as JSON and failed", db_content) raise + + +def add_in_list_sql_clause( + database_engine, column: str, iterable: Iterable, clauses: List[str], args: List +): + """Adds an SQL clause to the given list of clauses/args that checks the + given column is in the iterable. c.f. `make_in_list_sql_clause` + + Args: + database_engine + column: Name of the column + iterable: The values to check the column against. + clauses: A list to add the expanded clause to + args: A list of arguments that we append the args to. + """ + + clause, new_args = make_in_list_sql_clause(database_engine, column, iterable) + clauses.append(clause) + args.extend(new_args) + + +def make_in_list_sql_clause( + database_engine, column: str, iterable: Iterable +) -> Tuple[str, Iterable]: + """Returns an SQL clause that checks the given column is in the iterable. + + On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres + it expands to `column = ANY(?)`. While both DBs support the `IN` form, + using the `ANY` form on postgres means that it views queries with + different length iterables as the same, helping the query stats. + + Args: + database_engine + column: Name of the column + iterable: The values to check the column against. + + Returns: + A tuple of SQL query and the args + """ + + if isinstance(database_engine, PostgresEngine): + # This should hopefully be faster, but also makes postgres query + # stats easier to understand. + return "%s = ANY(?)" % (column,), [list(iterable)] + else: + return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), iterable