0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 04:53:53 +01:00

Remove concept of a non-limited stream. (#7011)

This commit is contained in:
Erik Johnston 2020-03-20 14:40:47 +00:00 committed by GitHub
parent caec7d4fa0
commit fdb1344716
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 72 additions and 68 deletions

1
changelog.d/7011.misc Normal file
View file

@ -0,0 +1 @@
Remove concept of a non-limited stream.

View file

@ -747,7 +747,7 @@ class PresenceHandler(object):
return False return False
async def get_all_presence_updates(self, last_id, current_id): async def get_all_presence_updates(self, last_id, current_id, limit):
""" """
Gets a list of presence update rows from between the given stream ids. Gets a list of presence update rows from between the given stream ids.
Each row has: Each row has:
@ -762,7 +762,7 @@ class PresenceHandler(object):
""" """
# TODO(markjh): replicate the unpersisted changes. # TODO(markjh): replicate the unpersisted changes.
# This could use the in-memory stores for recent changes. # This could use the in-memory stores for recent changes.
rows = await self.store.get_all_presence_updates(last_id, current_id) rows = await self.store.get_all_presence_updates(last_id, current_id, limit)
return rows return rows
def notify_new_event(self): def notify_new_event(self):

View file

@ -15,6 +15,7 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import List
from twisted.internet import defer from twisted.internet import defer
@ -257,7 +258,13 @@ class TypingHandler(object):
"typing_key", self._latest_room_serial, rooms=[member.room_id] "typing_key", self._latest_room_serial, rooms=[member.room_id]
) )
async def get_all_typing_updates(self, last_id, current_id): async def get_all_typing_updates(
self, last_id: int, current_id: int, limit: int
) -> List[dict]:
"""Get up to `limit` typing updates between the given tokens, earliest
updates first.
"""
if last_id == current_id: if last_id == current_id:
return [] return []
@ -275,7 +282,7 @@ class TypingHandler(object):
typing = self._room_typing[room_id] typing = self._room_typing[room_id]
rows.append((serial, room_id, list(typing))) rows.append((serial, room_id, list(typing)))
rows.sort() rows.sort()
return rows return rows[:limit]
def get_current_token(self): def get_current_token(self):
return self._latest_room_serial return self._latest_room_serial

View file

@ -166,11 +166,6 @@ class ReplicationStreamer(object):
self.pending_updates = False self.pending_updates = False
with Measure(self.clock, "repl.stream.get_updates"): with Measure(self.clock, "repl.stream.get_updates"):
# First we tell the streams that they should update their
# current tokens.
for stream in self.streams:
stream.advance_current_token()
all_streams = self.streams all_streams = self.streams
if self._replication_torture_level is not None: if self._replication_torture_level is not None:
@ -180,7 +175,7 @@ class ReplicationStreamer(object):
random.shuffle(all_streams) random.shuffle(all_streams)
for stream in all_streams: for stream in all_streams:
if stream.last_token == stream.upto_token: if stream.last_token == stream.current_token():
continue continue
if self._replication_torture_level: if self._replication_torture_level:
@ -192,7 +187,7 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s", "Getting stream: %s: %s -> %s",
stream.NAME, stream.NAME,
stream.last_token, stream.last_token,
stream.upto_token, stream.current_token(),
) )
try: try:
updates, current_token = await stream.get_updates() updates, current_token = await stream.get_updates()

View file

@ -17,10 +17,12 @@
import itertools import itertools
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Any, List, Optional from typing import Any, List, Optional, Tuple
import attr import attr
from synapse.types import JsonDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -119,13 +121,12 @@ class Stream(object):
"""Base class for the streams. """Base class for the streams.
Provides a `get_updates()` function that returns new updates since the last Provides a `get_updates()` function that returns new updates since the last
time it was called up until the point `advance_current_token` was called. time it was called.
""" """
NAME = None # type: str # The name of the stream NAME = None # type: str # The name of the stream
# The type of the row. Used by the default impl of parse_row. # The type of the row. Used by the default impl of parse_row.
ROW_TYPE = None # type: Any ROW_TYPE = None # type: Any
_LIMITED = True # Whether the update function takes a limit
@classmethod @classmethod
def parse_row(cls, row): def parse_row(cls, row):
@ -146,26 +147,15 @@ class Stream(object):
# The token from which we last asked for updates # The token from which we last asked for updates
self.last_token = self.current_token() self.last_token = self.current_token()
# The token that we will get updates up to
self.upto_token = self.current_token()
def advance_current_token(self):
"""Updates `upto_token` to "now", which updates up until which point
get_updates[_since] will fetch rows till.
"""
self.upto_token = self.current_token()
def discard_updates_and_advance(self): def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded, """Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers. e.g. when there are no currently connected workers.
""" """
self.upto_token = self.current_token() self.last_token = self.current_token()
self.last_token = self.upto_token
async def get_updates(self): async def get_updates(self):
"""Gets all updates since the last time this function was called (or """Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before), since the stream was constructed if it hadn't been called before).
until the `upto_token`
Returns: Returns:
Deferred[Tuple[List[Tuple[int, Any]], int]: Deferred[Tuple[List[Tuple[int, Any]], int]:
@ -178,44 +168,45 @@ class Stream(object):
return updates, current_token return updates, current_token
async def get_updates_since(self, from_token): async def get_updates_since(
self, from_token: int
) -> Tuple[List[Tuple[int, JsonDict]], int]:
"""Like get_updates except allows specifying from when we should """Like get_updates except allows specifying from when we should
stream updates stream updates
Returns: Returns:
Deferred[Tuple[List[Tuple[int, Any]], int]: Resolves to a pair `(updates, new_last_token)`, where `updates` is
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a a list of `(token, row)` entries and `new_last_token` is the new
list of ``(token, row)`` entries. ``row`` will be json-serialised and position in stream.
sent over the replication steam.
""" """
if from_token in ("NOW", "now"):
return [], self.upto_token
current_token = self.upto_token if from_token in ("NOW", "now"):
return [], self.current_token()
current_token = self.current_token()
from_token = int(from_token) from_token = int(from_token)
if from_token == current_token: if from_token == current_token:
return [], current_token return [], current_token
logger.info("get_updates_since: %s", self.__class__) rows = await self.update_function(
if self._LIMITED: from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
rows = await self.update_function( )
from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
)
# never turn more than MAX_EVENTS_BEHIND + 1 into updates. # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
else:
rows = await self.update_function(from_token, current_token)
updates = [(row[0], row[1:]) for row in rows] updates = [(row[0], row[1:]) for row in rows]
# check we didn't get more rows than the limit. # check we didn't get more rows than the limit.
# doing it like this allows the update_function to be a generator. # doing it like this allows the update_function to be a generator.
if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND: if len(updates) >= MAX_EVENTS_BEHIND:
raise Exception("stream %s has fallen behind" % (self.NAME)) raise Exception("stream %s has fallen behind" % (self.NAME))
# The update function didn't hit the limit, so we must have got all
# the updates to `current_token`, and can return that as our new
# stream position.
return updates, current_token return updates, current_token
def current_token(self): def current_token(self):
@ -227,9 +218,8 @@ class Stream(object):
""" """
raise NotImplementedError() raise NotImplementedError()
def update_function(self, from_token, current_token, limit=None): def update_function(self, from_token, current_token, limit):
"""Get updates between from_token and to_token. If Stream._LIMITED is """Get updates between from_token and to_token.
True then limit is provided, otherwise it's not.
Returns: Returns:
Deferred(list(tuple)): the first entry in the tuple is the token for Deferred(list(tuple)): the first entry in the tuple is the token for
@ -257,7 +247,6 @@ class BackfillStream(Stream):
class PresenceStream(Stream): class PresenceStream(Stream):
NAME = "presence" NAME = "presence"
_LIMITED = False
ROW_TYPE = PresenceStreamRow ROW_TYPE = PresenceStreamRow
def __init__(self, hs): def __init__(self, hs):
@ -272,7 +261,6 @@ class PresenceStream(Stream):
class TypingStream(Stream): class TypingStream(Stream):
NAME = "typing" NAME = "typing"
_LIMITED = False
ROW_TYPE = TypingStreamRow ROW_TYPE = TypingStreamRow
def __init__(self, hs): def __init__(self, hs):
@ -372,7 +360,6 @@ class DeviceListsStream(Stream):
""" """
NAME = "device_lists" NAME = "device_lists"
_LIMITED = False
ROW_TYPE = DeviceListsStreamRow ROW_TYPE = DeviceListsStreamRow
def __init__(self, hs): def __init__(self, hs):
@ -462,7 +449,6 @@ class UserSignatureStream(Stream):
""" """
NAME = "user_signature" NAME = "user_signature"
_LIMITED = False
ROW_TYPE = UserSignatureStreamRow ROW_TYPE = UserSignatureStreamRow
def __init__(self, hs): def __init__(self, hs):

View file

@ -576,7 +576,7 @@ class DeviceWorkerStore(SQLBaseStore):
return set() return set()
async def get_all_device_list_changes_for_remotes( async def get_all_device_list_changes_for_remotes(
self, from_key: int, to_key: int self, from_key: int, to_key: int, limit: int,
) -> List[Tuple[int, str]]: ) -> List[Tuple[int, str]]:
"""Return a list of `(stream_id, entity)` which is the combined list of """Return a list of `(stream_id, entity)` which is the combined list of
changes to devices and which destinations need to be poked. Entity is changes to devices and which destinations need to be poked. Entity is
@ -592,10 +592,16 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e ) AS e
WHERE ? < stream_id AND stream_id <= ? WHERE ? < stream_id AND stream_id <= ?
LIMIT ?
""" """
return await self.db.execute( return await self.db.execute(
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key "get_all_device_list_changes_for_remotes",
None,
sql,
from_key,
to_key,
limit,
) )
@cached(max_entries=10000) @cached(max_entries=10000)

View file

@ -537,7 +537,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return result return result
def get_all_user_signature_changes_for_remotes(self, from_key, to_key): def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
"""Return a list of changes from the user signature stream to notify remotes. """Return a list of changes from the user signature stream to notify remotes.
Note that the user signature stream represents when a user signs their Note that the user signature stream represents when a user signs their
device with their user-signing key, which is not published to other device with their user-signing key, which is not published to other
@ -552,13 +552,19 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Deferred[list[(int,str)]] a list of `(stream_id, user_id)` Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
""" """
sql = """ sql = """
SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream FROM user_signature_stream
WHERE ? < stream_id AND stream_id <= ? WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id ORDER BY stream_id ASC
LIMIT ?
""" """
return self.db.execute( return self.db.execute(
"get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key "get_all_user_signature_changes_for_remotes",
None,
sql,
from_key,
to_key,
limit,
) )

View file

@ -60,7 +60,7 @@ class PresenceStore(SQLBaseStore):
"status_msg": state.status_msg, "status_msg": state.status_msg,
"currently_active": state.currently_active, "currently_active": state.currently_active,
} }
for state in presence_states for stream_id, state in zip(stream_orderings, presence_states)
], ],
) )
@ -73,19 +73,22 @@ class PresenceStore(SQLBaseStore):
) )
txn.execute(sql + clause, [stream_id] + list(args)) txn.execute(sql + clause, [stream_id] + list(args))
def get_all_presence_updates(self, last_id, current_id): def get_all_presence_updates(self, last_id, current_id, limit):
if last_id == current_id: if last_id == current_id:
return defer.succeed([]) return defer.succeed([])
def get_all_presence_updates_txn(txn): def get_all_presence_updates_txn(txn):
sql = ( sql = """
"SELECT stream_id, user_id, state, last_active_ts," SELECT stream_id, user_id, state, last_active_ts,
" last_federation_update_ts, last_user_sync_ts, status_msg," last_federation_update_ts, last_user_sync_ts,
" currently_active" status_msg,
" FROM presence_stream" currently_active
" WHERE ? < stream_id AND stream_id <= ?" FROM presence_stream
) WHERE ? < stream_id AND stream_id <= ?
txn.execute(sql, (last_id, current_id)) ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall() return txn.fetchall()
return self.db.runInteraction( return self.db.runInteraction(