mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 10:03:47 +01:00
Refactor pagination DB API to return concrete type
This makes it easier to document what is being returned by the storage functions and what some functions expect as arguments.
This commit is contained in:
parent
7dd13415db
commit
05e0a2462c
1 changed files with 48 additions and 28 deletions
|
@ -47,6 +47,7 @@ import abc
|
|||
import logging
|
||||
|
||||
from six.moves import range
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -59,6 +60,12 @@ _STREAM_TOKEN = "stream"
|
|||
_TOPOLOGICAL_TOKEN = "topological"
|
||||
|
||||
|
||||
# Used as return values for pagination APIs
|
||||
_EventDictReturn = namedtuple("_EventDictReturn", (
|
||||
"event_id", "topological_ordering", "stream_ordering",
|
||||
))
|
||||
|
||||
|
||||
def lower_bound(token, engine, inclusive=False):
|
||||
inclusive = "=" if inclusive else ""
|
||||
if token.topological is None:
|
||||
|
@ -256,9 +263,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
" ORDER BY stream_ordering %s LIMIT ?"
|
||||
) % (order,)
|
||||
txn.execute(sql, (room_id, from_id, to_id, limit))
|
||||
|
||||
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
|
||||
else:
|
||||
sql = (
|
||||
"SELECT event_id, stream_ordering FROM events WHERE"
|
||||
"SELECT event_id, topological_ordering, stream_ordering"
|
||||
" FROM events"
|
||||
" WHERE"
|
||||
" room_id = ?"
|
||||
" AND not outlier"
|
||||
" AND stream_ordering <= ?"
|
||||
|
@ -266,14 +277,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
) % (order, order,)
|
||||
txn.execute(sql, (room_id, to_id, limit))
|
||||
|
||||
rows = self.cursor_to_dict(txn)
|
||||
rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
|
||||
|
||||
return rows
|
||||
|
||||
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
|
||||
|
||||
ret = yield self._get_events(
|
||||
[r["event_id"] for r in rows],
|
||||
[r.event_id for r in rows],
|
||||
get_prev_content=True
|
||||
)
|
||||
|
||||
|
@ -283,7 +294,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
ret.reverse()
|
||||
|
||||
if rows:
|
||||
key = "s%d" % min(r["stream_ordering"] for r in rows)
|
||||
key = "s%d" % min(r.stream_ordering for r in rows)
|
||||
else:
|
||||
# Assume we didn't get anything because there was nothing to
|
||||
# get.
|
||||
|
@ -330,14 +341,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
" ORDER BY stream_ordering ASC"
|
||||
)
|
||||
txn.execute(sql, (user_id, to_id,))
|
||||
rows = self.cursor_to_dict(txn)
|
||||
|
||||
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
|
||||
|
||||
return rows
|
||||
|
||||
rows = yield self.runInteraction("get_membership_changes_for_user", f)
|
||||
|
||||
ret = yield self._get_events(
|
||||
[r["event_id"] for r in rows],
|
||||
[r.event_id for r in rows],
|
||||
get_prev_content=True
|
||||
)
|
||||
|
||||
|
@ -353,14 +365,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
logger.debug("stream before")
|
||||
events = yield self._get_events(
|
||||
[r["event_id"] for r in rows],
|
||||
[r.event_id for r in rows],
|
||||
get_prev_content=True
|
||||
)
|
||||
logger.debug("stream after")
|
||||
|
||||
self._set_before_and_after(events, rows)
|
||||
|
||||
defer.returnValue((events, token))
|
||||
defer.returnValue((events, (token, end_token)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_recent_event_ids_for_room(self, room_id, limit, end_token):
|
||||
|
@ -372,15 +384,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
end_token (str): The stream token representing now.
|
||||
|
||||
Returns:
|
||||
Deferred[tuple[list[dict], tuple[str, str]]]: Returns a list of
|
||||
dicts (which include event_ids, etc), and a tuple for
|
||||
`(start_token, end_token)` representing the range of rows
|
||||
returned.
|
||||
The returned events are in ascending order.
|
||||
Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
|
||||
_EventDictReturn and a token pointint to the start of the returned
|
||||
events.
|
||||
The events returned are in ascending order.
|
||||
"""
|
||||
# Allow a zero limit here, and no-op.
|
||||
if limit == 0:
|
||||
defer.returnValue(([], (end_token, end_token)))
|
||||
defer.returnValue(([], end_token))
|
||||
|
||||
end_token = RoomStreamToken.parse_stream_token(end_token)
|
||||
|
||||
|
@ -392,7 +403,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
# We want to return the results in ascending order.
|
||||
rows.reverse()
|
||||
|
||||
defer.returnValue((rows, (token, str(end_token))))
|
||||
defer.returnValue((rows, token))
|
||||
|
||||
def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
|
||||
"""Gets details of the first event in a room at or after a stream ordering
|
||||
|
@ -496,10 +507,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
@staticmethod
|
||||
def _set_before_and_after(events, rows, topo_order=True):
|
||||
"""Inserts ordering information to events' internal metadata from
|
||||
the DB rows.
|
||||
|
||||
Args:
|
||||
events (list[FrozenEvent])
|
||||
rows (list[_EventDictReturn])
|
||||
topo_order (bool): Whether the events were ordered topologically
|
||||
or by stream ordering
|
||||
"""
|
||||
for event, row in zip(events, rows):
|
||||
stream = row["stream_ordering"]
|
||||
if topo_order:
|
||||
topo = event.depth
|
||||
stream = row.stream_ordering
|
||||
if topo_order and row.topological_ordering:
|
||||
topo = row.topological_ordering
|
||||
else:
|
||||
topo = None
|
||||
internal = event.internal_metadata
|
||||
|
@ -586,12 +606,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
rows, start_token = self._paginate_room_events_txn(
|
||||
txn, room_id, before_token, direction='b', limit=before_limit,
|
||||
)
|
||||
events_before = [r["event_id"] for r in rows]
|
||||
events_before = [r.event_id for r in rows]
|
||||
|
||||
rows, end_token = self._paginate_room_events_txn(
|
||||
txn, room_id, after_token, direction='f', limit=after_limit,
|
||||
)
|
||||
events_after = [r["event_id"] for r in rows]
|
||||
events_after = [r.event_id for r in rows]
|
||||
|
||||
return {
|
||||
"before": {
|
||||
|
@ -672,9 +692,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
those that match the filter.
|
||||
|
||||
Returns:
|
||||
tuple[list[dict], str]: Returns the results as a list of dicts and
|
||||
a token that points to the end of the result set. The dicts have
|
||||
the keys "event_id", "toplogical_ordering" and "stream_ordering".
|
||||
tuple[list[_EventDictReturn], str]: Returns the results as a list
|
||||
of _EventDictReturn and a token that points to the end of the
|
||||
result set.
|
||||
"""
|
||||
# Tokens really represent positions between elements, but we use
|
||||
# the convention of pointing to the event before the gap. Hence
|
||||
|
@ -725,11 +745,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
txn.execute(sql, args)
|
||||
|
||||
rows = self.cursor_to_dict(txn)
|
||||
rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
|
||||
|
||||
if rows:
|
||||
topo = rows[-1]["topological_ordering"]
|
||||
toke = rows[-1]["stream_ordering"]
|
||||
topo = rows[-1].topological_ordering
|
||||
toke = rows[-1].stream_ordering
|
||||
if direction == 'b':
|
||||
# Tokens are positions between events.
|
||||
# This token points *after* the last event in the chunk.
|
||||
|
@ -764,7 +784,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
Returns:
|
||||
tuple[list[dict], str]: Returns the results as a list of dicts and
|
||||
a token that points to the end of the result set. The dicts have
|
||||
the keys "event_id", "toplogical_ordering" and "stream_orderign".
|
||||
the keys "event_id", "topological_ordering" and "stream_orderign".
|
||||
"""
|
||||
|
||||
from_key = RoomStreamToken.parse(from_key)
|
||||
|
@ -777,7 +797,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
)
|
||||
|
||||
events = yield self._get_events(
|
||||
[r["event_id"] for r in rows],
|
||||
[r.event_id for r in rows],
|
||||
get_prev_content=True
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in a new issue