Merge pull request #5191 from matrix-org/erikj/refactor_pagination_bounds

Make generating SQL bounds for pagination generic
This commit is contained in:
Erik Johnston 2019-05-17 17:24:36 +01:00 committed by GitHub
commit 85ece3df46
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 132 additions and 59 deletions

1
changelog.d/5191.misc Normal file
View file

@ -0,0 +1 @@
Make generating SQL bounds for pagination generic.

View file

@ -64,57 +64,133 @@ _EventDictReturn = namedtuple(
) )
def lower_bound(token, engine, inclusive=False): def generate_pagination_where_clause(
inclusive = "=" if inclusive else "" direction, column_names, from_token, to_token, engine,
if token.topological is None: ):
return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering") """Creates an SQL expression to bound the columns by the pagination
else: tokens.
For example creates an SQL expression like:
(6, 7) >= (topological_ordering, stream_ordering)
AND (5, 3) < (topological_ordering, stream_ordering)
would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3).
Note that tokens are considered to be after the row they are in, e.g. if
a row A has a token T, then we consider A to be before T. This convention
is important when figuring out inequalities for the generated SQL, and
produces the following result:
- If paginating forwards then we exclude any rows matching the from
token, but include those that match the to token.
- If paginating backwards then we include any rows matching the from
token, but include those that match the to token.
Args:
direction (str): Whether we're paginating backwards("b") or
forwards ("f").
column_names (tuple[str, str]): The column names to bound. Must *not*
be user defined as these get inserted directly into the SQL
statement without escapes.
from_token (tuple[int, int]|None): The start point for the pagination.
This is an exclusive minimum bound if direction is "f", and an
inclusive maximum bound if direction is "b".
to_token (tuple[int, int]|None): The endpoint point for the pagination.
This is an inclusive maximum bound if direction is "f", and an
exclusive minimum bound if direction is "b".
engine: The database engine to generate the clauses for
Returns:
str: The sql expression
"""
assert direction in ("b", "f")
where_clause = []
if from_token:
where_clause.append(
_make_generic_sql_bound(
bound=">=" if direction == "b" else "<",
column_names=column_names,
values=from_token,
engine=engine,
)
)
if to_token:
where_clause.append(
_make_generic_sql_bound(
bound="<" if direction == "b" else ">=",
column_names=column_names,
values=to_token,
engine=engine,
)
)
return " AND ".join(where_clause)
def _make_generic_sql_bound(bound, column_names, values, engine):
"""Create an SQL expression that bounds the given column names by the
values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
Only works with two columns.
Older versions of SQLite don't support that syntax so we have to expand it
out manually.
Args:
bound (str): The comparison operator to use. One of ">", "<", ">=",
"<=", where the values are on the left and columns on the right.
names (tuple[str, str]): The column names. Must *not* be user defined
as these get inserted directly into the SQL statement without
escapes.
values (tuple[int|None, int]): The values to bound the columns by. If
the first value is None then only creates a bound on the second
column.
engine: The database engine to generate the SQL for
Returns:
str
"""
assert(bound in (">", "<", ">=", "<="))
name1, name2 = column_names
val1, val2 = values
if val1 is None:
val2 = int(val2)
return "(%d %s %s)" % (val2, bound, name2)
val1 = int(val1)
val2 = int(val2)
if isinstance(engine, PostgresEngine): if isinstance(engine, PostgresEngine):
# Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres. # use the later form when running against postgres.
return "((%d,%d) <%s (%s,%s))" % ( return "((%d,%d) %s (%s,%s))" % (
token.topological, val1, val2,
token.stream, bound,
inclusive, name1, name2,
"topological_ordering",
"stream_ordering",
)
return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
token.topological,
"topological_ordering",
token.topological,
"topological_ordering",
token.stream,
inclusive,
"stream_ordering",
) )
# We want to generate queries of e.g. the form:
#
# (val1 < name1 OR (val1 = name1 AND val2 <= name2))
#
# which is equivalent to (val1, val2) < (name1, name2)
def upper_bound(token, engine, inclusive=True): return """(
inclusive = "=" if inclusive else "" {val1:d} {strict_bound} {name1}
if token.topological is None: OR ({val1:d} = {name1} AND {val2:d} {bound} {name2})
return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering") )""".format(
else: name1=name1,
if isinstance(engine, PostgresEngine): val1=val1,
# Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well name2=name2,
# as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we val2=val2,
# use the later form when running against postgres. strict_bound=bound[0], # The first bound must always be strict equality here
return "((%d,%d) >%s (%s,%s))" % ( bound=bound,
token.topological,
token.stream,
inclusive,
"topological_ordering",
"stream_ordering",
)
return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
token.topological,
"topological_ordering",
token.topological,
"topological_ordering",
token.stream,
inclusive,
"stream_ordering",
) )
@ -762,19 +838,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
args = [False, room_id] args = [False, room_id]
if direction == 'b': if direction == 'b':
order = "DESC" order = "DESC"
bounds = upper_bound(from_token, self.database_engine)
if to_token:
bounds = "%s AND %s" % (
bounds,
lower_bound(to_token, self.database_engine),
)
else: else:
order = "ASC" order = "ASC"
bounds = lower_bound(from_token, self.database_engine)
if to_token: bounds = generate_pagination_where_clause(
bounds = "%s AND %s" % ( direction=direction,
bounds, column_names=("topological_ordering", "stream_ordering"),
upper_bound(to_token, self.database_engine), from_token=from_token,
to_token=to_token,
engine=self.database_engine,
) )
filter_clause, filter_args = filter_to_clause(event_filter) filter_clause, filter_args = filter_to_clause(event_filter)