mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 14:23:50 +01:00
Add helper funcs to use postgres ANY
This means that we can write queries with `col = ANY(?)`, which helps postgres.
This commit is contained in:
parent
b5b03b7079
commit
b4fbf71187
1 changed files with 56 additions and 8 deletions
|
@ -20,6 +20,7 @@ import random
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from typing import Iterable, List, Tuple
|
||||||
|
|
||||||
from six import PY2, iteritems, iterkeys, itervalues
|
from six import PY2, iteritems, iterkeys, itervalues
|
||||||
from six.moves import builtins, intern, range
|
from six.moves import builtins, intern, range
|
||||||
|
@ -1162,19 +1163,20 @@ class SQLBaseStore(object):
|
||||||
if not iterable:
|
if not iterable:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
|
|
||||||
|
|
||||||
clauses = []
|
clauses = []
|
||||||
values = []
|
values = []
|
||||||
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
|
|
||||||
values.extend(iterable)
|
add_in_list_sql_clause(txn.database_engine, column, iterable, clauses, values)
|
||||||
|
|
||||||
for key, value in iteritems(keyvalues):
|
for key, value in iteritems(keyvalues):
|
||||||
clauses.append("%s = ?" % (key,))
|
clauses.append("%s = ?" % (key,))
|
||||||
values.append(value)
|
values.append(value)
|
||||||
|
|
||||||
if clauses:
|
sql = "SELECT %s FROM %s WHERE %s" % (
|
||||||
sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
|
", ".join(retcols),
|
||||||
|
table,
|
||||||
|
" AND ".join(clauses),
|
||||||
|
)
|
||||||
|
|
||||||
txn.execute(sql, values)
|
txn.execute(sql, values)
|
||||||
return cls.cursor_to_dict(txn)
|
return cls.cursor_to_dict(txn)
|
||||||
|
@ -1325,8 +1327,8 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
clauses = []
|
clauses = []
|
||||||
values = []
|
values = []
|
||||||
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
|
|
||||||
values.extend(iterable)
|
add_in_list_sql_clause(txn.database_engine, column, iterable, clauses, values)
|
||||||
|
|
||||||
for key, value in iteritems(keyvalues):
|
for key, value in iteritems(keyvalues):
|
||||||
clauses.append("%s = ?" % (key,))
|
clauses.append("%s = ?" % (key,))
|
||||||
|
@ -1693,3 +1695,49 @@ def db_to_json(db_content):
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
|
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def add_in_list_sql_clause(
|
||||||
|
database_engine, column: str, iterable: Iterable, clauses: List[str], args: List
|
||||||
|
):
|
||||||
|
"""Adds an SQL clause to the given list of clauses/args that checks the
|
||||||
|
given column is in the iterable. c.f. `make_in_list_sql_clause`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_engine
|
||||||
|
column: Name of the column
|
||||||
|
iterable: The values to check the column against.
|
||||||
|
clauses: A list to add the expanded clause to
|
||||||
|
args: A list of arguments that we append the args to.
|
||||||
|
"""
|
||||||
|
|
||||||
|
clause, new_args = make_in_list_sql_clause(database_engine, column, iterable)
|
||||||
|
clauses.append(clause)
|
||||||
|
args.extend(new_args)
|
||||||
|
|
||||||
|
|
||||||
|
def make_in_list_sql_clause(
|
||||||
|
database_engine, column: str, iterable: Iterable
|
||||||
|
) -> Tuple[str, Iterable]:
|
||||||
|
"""Returns an SQL clause that checks the given column is in the iterable.
|
||||||
|
|
||||||
|
On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
|
||||||
|
it expands to `column = ANY(?)`. While both DBs support the `IN` form,
|
||||||
|
using the `ANY` form on postgres means that it views queries with
|
||||||
|
different length iterables as the same, helping the query stats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_engine
|
||||||
|
column: Name of the column
|
||||||
|
iterable: The values to check the column against.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of SQL query and the args
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(database_engine, PostgresEngine):
|
||||||
|
# This should hopefully be faster, but also makes postgres query
|
||||||
|
# stats easier to understand.
|
||||||
|
return "%s = ANY(?)" % (column,), [list(iterable)]
|
||||||
|
else:
|
||||||
|
return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), iterable
|
||||||
|
|
Loading…
Reference in a new issue