mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 10:13:53 +01:00
Merge pull request #5191 from matrix-org/erikj/refactor_pagination_bounds
Make generating SQL bounds for pagination generic
This commit is contained in:
commit
85ece3df46
2 changed files with 132 additions and 59 deletions
1
changelog.d/5191.misc
Normal file
1
changelog.d/5191.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Make generating SQL bounds for pagination generic.
|
|
@ -64,57 +64,133 @@ _EventDictReturn = namedtuple(
|
|||
)
|
||||
|
||||
|
||||
def lower_bound(token, engine, inclusive=False):
|
||||
inclusive = "=" if inclusive else ""
|
||||
if token.topological is None:
|
||||
return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
|
||||
else:
|
||||
def generate_pagination_where_clause(
|
||||
direction, column_names, from_token, to_token, engine,
|
||||
):
|
||||
"""Creates an SQL expression to bound the columns by the pagination
|
||||
tokens.
|
||||
|
||||
For example creates an SQL expression like:
|
||||
|
||||
(6, 7) >= (topological_ordering, stream_ordering)
|
||||
AND (5, 3) < (topological_ordering, stream_ordering)
|
||||
|
||||
would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3).
|
||||
|
||||
Note that tokens are considered to be after the row they are in, e.g. if
|
||||
a row A has a token T, then we consider A to be before T. This convention
|
||||
is important when figuring out inequalities for the generated SQL, and
|
||||
produces the following result:
|
||||
- If paginating forwards then we exclude any rows matching the from
|
||||
token, but include those that match the to token.
|
||||
- If paginating backwards then we include any rows matching the from
|
||||
token, but include those that match the to token.
|
||||
|
||||
Args:
|
||||
direction (str): Whether we're paginating backwards("b") or
|
||||
forwards ("f").
|
||||
column_names (tuple[str, str]): The column names to bound. Must *not*
|
||||
be user defined as these get inserted directly into the SQL
|
||||
statement without escapes.
|
||||
from_token (tuple[int, int]|None): The start point for the pagination.
|
||||
This is an exclusive minimum bound if direction is "f", and an
|
||||
inclusive maximum bound if direction is "b".
|
||||
to_token (tuple[int, int]|None): The endpoint point for the pagination.
|
||||
This is an inclusive maximum bound if direction is "f", and an
|
||||
exclusive minimum bound if direction is "b".
|
||||
engine: The database engine to generate the clauses for
|
||||
|
||||
Returns:
|
||||
str: The sql expression
|
||||
"""
|
||||
assert direction in ("b", "f")
|
||||
|
||||
where_clause = []
|
||||
if from_token:
|
||||
where_clause.append(
|
||||
_make_generic_sql_bound(
|
||||
bound=">=" if direction == "b" else "<",
|
||||
column_names=column_names,
|
||||
values=from_token,
|
||||
engine=engine,
|
||||
)
|
||||
)
|
||||
|
||||
if to_token:
|
||||
where_clause.append(
|
||||
_make_generic_sql_bound(
|
||||
bound="<" if direction == "b" else ">=",
|
||||
column_names=column_names,
|
||||
values=to_token,
|
||||
engine=engine,
|
||||
)
|
||||
)
|
||||
|
||||
return " AND ".join(where_clause)
|
||||
|
||||
|
||||
def _make_generic_sql_bound(bound, column_names, values, engine):
|
||||
"""Create an SQL expression that bounds the given column names by the
|
||||
values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
|
||||
|
||||
Only works with two columns.
|
||||
|
||||
Older versions of SQLite don't support that syntax so we have to expand it
|
||||
out manually.
|
||||
|
||||
Args:
|
||||
bound (str): The comparison operator to use. One of ">", "<", ">=",
|
||||
"<=", where the values are on the left and columns on the right.
|
||||
names (tuple[str, str]): The column names. Must *not* be user defined
|
||||
as these get inserted directly into the SQL statement without
|
||||
escapes.
|
||||
values (tuple[int|None, int]): The values to bound the columns by. If
|
||||
the first value is None then only creates a bound on the second
|
||||
column.
|
||||
engine: The database engine to generate the SQL for
|
||||
|
||||
Returns:
|
||||
str
|
||||
"""
|
||||
|
||||
assert(bound in (">", "<", ">=", "<="))
|
||||
|
||||
name1, name2 = column_names
|
||||
val1, val2 = values
|
||||
|
||||
if val1 is None:
|
||||
val2 = int(val2)
|
||||
return "(%d %s %s)" % (val2, bound, name2)
|
||||
|
||||
val1 = int(val1)
|
||||
val2 = int(val2)
|
||||
|
||||
if isinstance(engine, PostgresEngine):
|
||||
# Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
|
||||
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
|
||||
# use the later form when running against postgres.
|
||||
return "((%d,%d) <%s (%s,%s))" % (
|
||||
token.topological,
|
||||
token.stream,
|
||||
inclusive,
|
||||
"topological_ordering",
|
||||
"stream_ordering",
|
||||
)
|
||||
return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
|
||||
token.topological,
|
||||
"topological_ordering",
|
||||
token.topological,
|
||||
"topological_ordering",
|
||||
token.stream,
|
||||
inclusive,
|
||||
"stream_ordering",
|
||||
return "((%d,%d) %s (%s,%s))" % (
|
||||
val1, val2,
|
||||
bound,
|
||||
name1, name2,
|
||||
)
|
||||
|
||||
# We want to generate queries of e.g. the form:
|
||||
#
|
||||
# (val1 < name1 OR (val1 = name1 AND val2 <= name2))
|
||||
#
|
||||
# which is equivalent to (val1, val2) < (name1, name2)
|
||||
|
||||
def upper_bound(token, engine, inclusive=True):
|
||||
inclusive = "=" if inclusive else ""
|
||||
if token.topological is None:
|
||||
return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
|
||||
else:
|
||||
if isinstance(engine, PostgresEngine):
|
||||
# Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
|
||||
# as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
|
||||
# use the later form when running against postgres.
|
||||
return "((%d,%d) >%s (%s,%s))" % (
|
||||
token.topological,
|
||||
token.stream,
|
||||
inclusive,
|
||||
"topological_ordering",
|
||||
"stream_ordering",
|
||||
)
|
||||
return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
|
||||
token.topological,
|
||||
"topological_ordering",
|
||||
token.topological,
|
||||
"topological_ordering",
|
||||
token.stream,
|
||||
inclusive,
|
||||
"stream_ordering",
|
||||
return """(
|
||||
{val1:d} {strict_bound} {name1}
|
||||
OR ({val1:d} = {name1} AND {val2:d} {bound} {name2})
|
||||
)""".format(
|
||||
name1=name1,
|
||||
val1=val1,
|
||||
name2=name2,
|
||||
val2=val2,
|
||||
strict_bound=bound[0], # The first bound must always be strict equality here
|
||||
bound=bound,
|
||||
)
|
||||
|
||||
|
||||
|
@ -762,19 +838,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
args = [False, room_id]
|
||||
if direction == 'b':
|
||||
order = "DESC"
|
||||
bounds = upper_bound(from_token, self.database_engine)
|
||||
if to_token:
|
||||
bounds = "%s AND %s" % (
|
||||
bounds,
|
||||
lower_bound(to_token, self.database_engine),
|
||||
)
|
||||
else:
|
||||
order = "ASC"
|
||||
bounds = lower_bound(from_token, self.database_engine)
|
||||
if to_token:
|
||||
bounds = "%s AND %s" % (
|
||||
bounds,
|
||||
upper_bound(to_token, self.database_engine),
|
||||
|
||||
bounds = generate_pagination_where_clause(
|
||||
direction=direction,
|
||||
column_names=("topological_ordering", "stream_ordering"),
|
||||
from_token=from_token,
|
||||
to_token=to_token,
|
||||
engine=self.database_engine,
|
||||
)
|
||||
|
||||
filter_clause, filter_args = filter_to_clause(event_filter)
|
||||
|
|
Loading…
Reference in a new issue