0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-16 23:13:50 +01:00

Fix assertion to stop transaction queue getting wedged

... and update some docstrings to correctly reflect the types being used.

get_new_device_msgs_for_remote can return a long under some circumstances,
which was being stored in last_device_list_stream_id_by_dest, and was then
upsetting things on the next loop.
This commit is contained in:
Richard van der Hoff 2017-03-15 12:16:55 +00:00
parent 3b2dd1b3c2
commit 29ed09e80a
6 changed files with 29 additions and 5 deletions

View file

@ -99,7 +99,12 @@ class TransactionQueue(object):
# destination -> list of tuple(failure, deferred) # destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {} self.pending_failures_by_dest = {}
# destination -> stream_id of last successfully sent to-device message.
# NB: may be a long or an int.
self.last_device_stream_id_by_dest = {} self.last_device_stream_id_by_dest = {}
# destination -> stream_id of last successfully sent device list
# update.
self.last_device_list_stream_id_by_dest = {} self.last_device_list_stream_id_by_dest = {}
# HACK to get unique tx id # HACK to get unique tx id

View file

@ -27,4 +27,9 @@ class SlavedIdTracker(object):
self._current = (max if self.step > 0 else min)(self._current, new_id) self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self): def get_current_token(self):
"""
Returns:
int
"""
return self._current return self._current

View file

@ -357,12 +357,12 @@ class DeviceInboxStore(BackgroundUpdateStore):
""" """
Args: Args:
destination(str): The name of the remote server. destination(str): The name of the remote server.
last_stream_id(int): The last position of the device message stream last_stream_id(int|long): The last position of the device message stream
that the server sent up to. that the server sent up to.
current_stream_id(int): The current position of the device current_stream_id(int|long): The current position of the device
message stream. message stream.
Returns: Returns:
Deferred ([dict], int): List of messages for the device and where Deferred ([dict], int|long): List of messages for the device and where
in the stream the messages got to. in the stream the messages got to.
""" """

View file

@ -308,7 +308,7 @@ class DeviceStore(SQLBaseStore):
"""Get stream of updates to send to remote servers """Get stream of updates to send to remote servers
Returns: Returns:
(now_stream_id, [ { updates }, .. ]) (int, list[dict]): current stream id and list of updates
""" """
now_stream_id = self._device_list_id_gen.get_current_token() now_stream_id = self._device_list_id_gen.get_current_token()

View file

@ -30,6 +30,17 @@ class IdGenerator(object):
def _load_current_id(db_conn, table, column, step=1): def _load_current_id(db_conn, table, column, step=1):
"""
Args:
db_conn (object):
table (str):
column (str):
step (int):
Returns:
int
"""
cur = db_conn.cursor() cur = db_conn.cursor()
if step == 1: if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
@ -131,6 +142,9 @@ class StreamIdGenerator(object):
def get_current_token(self): def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or """Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted. equal to it have been successfully persisted.
Returns:
int
""" """
with self._lock: with self._lock:
if self._unfinished_ids: if self._unfinished_ids:

View file

@ -50,7 +50,7 @@ class StreamChangeCache(object):
def has_entity_changed(self, entity, stream_pos): def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos """Returns True if the entity may have been updated since stream_pos
""" """
assert type(stream_pos) is int assert type(stream_pos) is int or type(stream_pos) is long
if stream_pos < self._earliest_known_stream_pos: if stream_pos < self._earliest_known_stream_pos:
self.metrics.inc_misses() self.metrics.inc_misses()