0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 14:23:50 +01:00

Add helper funcs to use postgres ANY

This means that we can write queries with `col = ANY(?)`, which helps
postgres.
This commit is contained in:
Erik Johnston 2019-10-02 19:06:12 +01:00
parent b5b03b7079
commit b4fbf71187

View file

@ -20,6 +20,7 @@ import random
import sys import sys
import threading import threading
import time import time
from typing import Iterable, List, Tuple
from six import PY2, iteritems, iterkeys, itervalues from six import PY2, iteritems, iterkeys, itervalues
from six.moves import builtins, intern, range from six.moves import builtins, intern, range
@ -1162,19 +1163,20 @@ class SQLBaseStore(object):
if not iterable: if not iterable:
return [] return []
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
clauses = [] clauses = []
values = [] values = []
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
values.extend(iterable) add_in_list_sql_clause(txn.database_engine, column, iterable, clauses, values)
for key, value in iteritems(keyvalues): for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,)) clauses.append("%s = ?" % (key,))
values.append(value) values.append(value)
if clauses: sql = "SELECT %s FROM %s WHERE %s" % (
sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) ", ".join(retcols),
table,
" AND ".join(clauses),
)
txn.execute(sql, values) txn.execute(sql, values)
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
@ -1325,8 +1327,8 @@ class SQLBaseStore(object):
clauses = [] clauses = []
values = [] values = []
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
values.extend(iterable) add_in_list_sql_clause(txn.database_engine, column, iterable, clauses, values)
for key, value in iteritems(keyvalues): for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,)) clauses.append("%s = ?" % (key,))
@ -1693,3 +1695,49 @@ def db_to_json(db_content):
except Exception: except Exception:
logging.warning("Tried to decode '%r' as JSON and failed", db_content) logging.warning("Tried to decode '%r' as JSON and failed", db_content)
raise raise
def add_in_list_sql_clause(
database_engine, column: str, iterable: Iterable, clauses: List[str], args: List
):
"""Adds an SQL clause to the given list of clauses/args that checks the
given column is in the iterable. c.f. `make_in_list_sql_clause`
Args:
database_engine
column: Name of the column
iterable: The values to check the column against.
clauses: A list to add the expanded clause to
args: A list of arguments that we append the args to.
"""
clause, new_args = make_in_list_sql_clause(database_engine, column, iterable)
clauses.append(clause)
args.extend(new_args)
def make_in_list_sql_clause(
database_engine, column: str, iterable: Iterable
) -> Tuple[str, Iterable]:
"""Returns an SQL clause that checks the given column is in the iterable.
On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
it expands to `column = ANY(?)`. While both DBs support the `IN` form,
using the `ANY` form on postgres means that it views queries with
different length iterables as the same, helping the query stats.
Args:
database_engine
column: Name of the column
iterable: The values to check the column against.
Returns:
A tuple of SQL query and the args
"""
if isinstance(database_engine, PostgresEngine):
# This should hopefully be faster, but also makes postgres query
# stats easier to understand.
return "%s = ANY(?)" % (column,), [list(iterable)]
else:
return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), iterable