Refactor pagination DB API to return concrete type

This makes it easier to document what is being returned by the storage
functions and what some functions expect as arguments.
This commit is contained in:
Erik Johnston 2018-05-09 11:18:23 +01:00
parent 7dd13415db
commit 05e0a2462c

View file

@ -47,6 +47,7 @@ import abc
import logging import logging
from six.moves import range from six.moves import range
from collections import namedtuple
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -59,6 +60,12 @@ _STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological" _TOPOLOGICAL_TOKEN = "topological"
# Used as return values for pagination APIs
_EventDictReturn = namedtuple("_EventDictReturn", (
"event_id", "topological_ordering", "stream_ordering",
))
def lower_bound(token, engine, inclusive=False): def lower_bound(token, engine, inclusive=False):
inclusive = "=" if inclusive else "" inclusive = "=" if inclusive else ""
if token.topological is None: if token.topological is None:
@ -256,9 +263,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" ORDER BY stream_ordering %s LIMIT ?" " ORDER BY stream_ordering %s LIMIT ?"
) % (order,) ) % (order,)
txn.execute(sql, (room_id, from_id, to_id, limit)) txn.execute(sql, (room_id, from_id, to_id, limit))
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
else: else:
sql = ( sql = (
"SELECT event_id, stream_ordering FROM events WHERE" "SELECT event_id, topological_ordering, stream_ordering"
" FROM events"
" WHERE"
" room_id = ?" " room_id = ?"
" AND not outlier" " AND not outlier"
" AND stream_ordering <= ?" " AND stream_ordering <= ?"
@ -266,14 +277,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
) % (order, order,) ) % (order, order,)
txn.execute(sql, (room_id, to_id, limit)) txn.execute(sql, (room_id, to_id, limit))
rows = self.cursor_to_dict(txn) rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
return rows return rows
rows = yield self.runInteraction("get_room_events_stream_for_room", f) rows = yield self.runInteraction("get_room_events_stream_for_room", f)
ret = yield self._get_events( ret = yield self._get_events(
[r["event_id"] for r in rows], [r.event_id for r in rows],
get_prev_content=True get_prev_content=True
) )
@ -283,7 +294,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ret.reverse() ret.reverse()
if rows: if rows:
key = "s%d" % min(r["stream_ordering"] for r in rows) key = "s%d" % min(r.stream_ordering for r in rows)
else: else:
# Assume we didn't get anything because there was nothing to # Assume we didn't get anything because there was nothing to
# get. # get.
@ -330,14 +341,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" ORDER BY stream_ordering ASC" " ORDER BY stream_ordering ASC"
) )
txn.execute(sql, (user_id, to_id,)) txn.execute(sql, (user_id, to_id,))
rows = self.cursor_to_dict(txn)
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows return rows
rows = yield self.runInteraction("get_membership_changes_for_user", f) rows = yield self.runInteraction("get_membership_changes_for_user", f)
ret = yield self._get_events( ret = yield self._get_events(
[r["event_id"] for r in rows], [r.event_id for r in rows],
get_prev_content=True get_prev_content=True
) )
@ -353,14 +365,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
logger.debug("stream before") logger.debug("stream before")
events = yield self._get_events( events = yield self._get_events(
[r["event_id"] for r in rows], [r.event_id for r in rows],
get_prev_content=True get_prev_content=True
) )
logger.debug("stream after") logger.debug("stream after")
self._set_before_and_after(events, rows) self._set_before_and_after(events, rows)
defer.returnValue((events, token)) defer.returnValue((events, (token, end_token)))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_recent_event_ids_for_room(self, room_id, limit, end_token): def get_recent_event_ids_for_room(self, room_id, limit, end_token):
@ -372,15 +384,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token (str): The stream token representing now. end_token (str): The stream token representing now.
Returns: Returns:
Deferred[tuple[list[dict], tuple[str, str]]]: Returns a list of Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
dicts (which include event_ids, etc), and a tuple for _EventDictReturn and a token pointint to the start of the returned
`(start_token, end_token)` representing the range of rows events.
returned. The events returned are in ascending order.
The returned events are in ascending order.
""" """
# Allow a zero limit here, and no-op. # Allow a zero limit here, and no-op.
if limit == 0: if limit == 0:
defer.returnValue(([], (end_token, end_token))) defer.returnValue(([], end_token))
end_token = RoomStreamToken.parse_stream_token(end_token) end_token = RoomStreamToken.parse_stream_token(end_token)
@ -392,7 +403,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# We want to return the results in ascending order. # We want to return the results in ascending order.
rows.reverse() rows.reverse()
defer.returnValue((rows, (token, str(end_token)))) defer.returnValue((rows, token))
def get_room_event_after_stream_ordering(self, room_id, stream_ordering): def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
"""Gets details of the first event in a room at or after a stream ordering """Gets details of the first event in a room at or after a stream ordering
@ -496,10 +507,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@staticmethod @staticmethod
def _set_before_and_after(events, rows, topo_order=True): def _set_before_and_after(events, rows, topo_order=True):
"""Inserts ordering information to events' internal metadata from
the DB rows.
Args:
events (list[FrozenEvent])
rows (list[_EventDictReturn])
topo_order (bool): Whether the events were ordered topologically
or by stream ordering
"""
for event, row in zip(events, rows): for event, row in zip(events, rows):
stream = row["stream_ordering"] stream = row.stream_ordering
if topo_order: if topo_order and row.topological_ordering:
topo = event.depth topo = row.topological_ordering
else: else:
topo = None topo = None
internal = event.internal_metadata internal = event.internal_metadata
@ -586,12 +606,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows, start_token = self._paginate_room_events_txn( rows, start_token = self._paginate_room_events_txn(
txn, room_id, before_token, direction='b', limit=before_limit, txn, room_id, before_token, direction='b', limit=before_limit,
) )
events_before = [r["event_id"] for r in rows] events_before = [r.event_id for r in rows]
rows, end_token = self._paginate_room_events_txn( rows, end_token = self._paginate_room_events_txn(
txn, room_id, after_token, direction='f', limit=after_limit, txn, room_id, after_token, direction='f', limit=after_limit,
) )
events_after = [r["event_id"] for r in rows] events_after = [r.event_id for r in rows]
return { return {
"before": { "before": {
@ -672,9 +692,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
those that match the filter. those that match the filter.
Returns: Returns:
tuple[list[dict], str]: Returns the results as a list of dicts and tuple[list[_EventDictReturn], str]: Returns the results as a list
a token that points to the end of the result set. The dicts have of _EventDictReturn and a token that points to the end of the
the keys "event_id", "toplogical_ordering" and "stream_ordering". result set.
""" """
# Tokens really represent positions between elements, but we use # Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence # the convention of pointing to the event before the gap. Hence
@ -725,11 +745,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(sql, args) txn.execute(sql, args)
rows = self.cursor_to_dict(txn) rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
if rows: if rows:
topo = rows[-1]["topological_ordering"] topo = rows[-1].topological_ordering
toke = rows[-1]["stream_ordering"] toke = rows[-1].stream_ordering
if direction == 'b': if direction == 'b':
# Tokens are positions between events. # Tokens are positions between events.
# This token points *after* the last event in the chunk. # This token points *after* the last event in the chunk.
@ -764,7 +784,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns: Returns:
tuple[list[dict], str]: Returns the results as a list of dicts and tuple[list[dict], str]: Returns the results as a list of dicts and
a token that points to the end of the result set. The dicts have a token that points to the end of the result set. The dicts have
the keys "event_id", "toplogical_ordering" and "stream_orderign". the keys "event_id", "topological_ordering" and "stream_orderign".
""" """
from_key = RoomStreamToken.parse(from_key) from_key = RoomStreamToken.parse(from_key)
@ -777,7 +797,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
) )
events = yield self._get_events( events = yield self._get_events(
[r["event_id"] for r in rows], [r.event_id for r in rows],
get_prev_content=True get_prev_content=True
) )