mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-16 15:01:23 +01:00
Refactor the notifier.wait_for_events code to be clearer. Add _NotifierUserStream.new_listener that accpets a token to avoid races.
This commit is contained in:
parent
050ebccf30
commit
22049ea700
3 changed files with 72 additions and 69 deletions
|
@ -16,7 +16,7 @@
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor, ObservableDeferred
|
||||||
from synapse.types import StreamToken
|
from synapse.types import StreamToken
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
|
@ -45,20 +45,16 @@ class _NotificationListener(object):
|
||||||
The events stream handler will have yielded to the deferred, so to
|
The events stream handler will have yielded to the deferred, so to
|
||||||
notify the handler it is sufficient to resolve the deferred.
|
notify the handler it is sufficient to resolve the deferred.
|
||||||
"""
|
"""
|
||||||
|
__slots__ = ["deferred"]
|
||||||
|
|
||||||
def __init__(self, deferred):
|
def __init__(self, deferred):
|
||||||
self.deferred = deferred
|
object.__setattr__(self, "deferred", deferred)
|
||||||
|
|
||||||
def notified(self):
|
def __getattr__(self, name):
|
||||||
return self.deferred.called
|
return getattr(self.deferred, name)
|
||||||
|
|
||||||
def notify(self, token):
|
def __setattr__(self, name, value):
|
||||||
""" Inform whoever is listening about the new events.
|
setattr(self.deferred, name, value)
|
||||||
"""
|
|
||||||
try:
|
|
||||||
self.deferred.callback(token)
|
|
||||||
except defer.AlreadyCalledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class _NotifierUserStream(object):
|
class _NotifierUserStream(object):
|
||||||
|
@ -75,11 +71,12 @@ class _NotifierUserStream(object):
|
||||||
appservice=None):
|
appservice=None):
|
||||||
self.user = str(user)
|
self.user = str(user)
|
||||||
self.appservice = appservice
|
self.appservice = appservice
|
||||||
self.listeners = set()
|
|
||||||
self.rooms = set(rooms)
|
self.rooms = set(rooms)
|
||||||
self.current_token = current_token
|
self.current_token = current_token
|
||||||
self.last_notified_ms = time_now_ms
|
self.last_notified_ms = time_now_ms
|
||||||
|
|
||||||
|
self.notify_deferred = ObservableDeferred(defer.Deferred())
|
||||||
|
|
||||||
def notify(self, stream_key, stream_id, time_now_ms):
|
def notify(self, stream_key, stream_id, time_now_ms):
|
||||||
"""Notify any listeners for this user of a new event from an
|
"""Notify any listeners for this user of a new event from an
|
||||||
event source.
|
event source.
|
||||||
|
@ -91,12 +88,10 @@ class _NotifierUserStream(object):
|
||||||
self.current_token = self.current_token.copy_and_advance(
|
self.current_token = self.current_token.copy_and_advance(
|
||||||
stream_key, stream_id
|
stream_key, stream_id
|
||||||
)
|
)
|
||||||
if self.listeners:
|
|
||||||
self.last_notified_ms = time_now_ms
|
self.last_notified_ms = time_now_ms
|
||||||
listeners = self.listeners
|
noify_deferred = self.notify_deferred
|
||||||
self.listeners = set()
|
self.notify_deferred = ObservableDeferred(defer.Deferred())
|
||||||
for listener in listeners:
|
noify_deferred.callback(self.current_token)
|
||||||
listener.notify(self.current_token)
|
|
||||||
|
|
||||||
def remove(self, notifier):
|
def remove(self, notifier):
|
||||||
""" Remove this listener from all the indexes in the Notifier
|
""" Remove this listener from all the indexes in the Notifier
|
||||||
|
@ -114,6 +109,18 @@ class _NotifierUserStream(object):
|
||||||
self.appservice, set()
|
self.appservice, set()
|
||||||
).discard(self)
|
).discard(self)
|
||||||
|
|
||||||
|
def count_listeners(self):
|
||||||
|
return len(self.noify_deferred.observers())
|
||||||
|
|
||||||
|
def new_listener(self, token):
|
||||||
|
"""Returns a deferred that is resolved when there is a new token
|
||||||
|
greater than the given token.
|
||||||
|
"""
|
||||||
|
if self.current_token.is_after(token):
|
||||||
|
return _NotificationListener(defer.succeed(self.current_token))
|
||||||
|
else:
|
||||||
|
return _NotificationListener(self.notify_deferred.observe())
|
||||||
|
|
||||||
|
|
||||||
class Notifier(object):
|
class Notifier(object):
|
||||||
""" This class is responsible for notifying any listeners when there are
|
""" This class is responsible for notifying any listeners when there are
|
||||||
|
@ -158,7 +165,7 @@ class Notifier(object):
|
||||||
for x in self.appservice_to_user_streams.values():
|
for x in self.appservice_to_user_streams.values():
|
||||||
all_user_streams |= x
|
all_user_streams |= x
|
||||||
|
|
||||||
return sum(len(stream.listeners) for stream in all_user_streams)
|
return sum(stream.count_listeners() for stream in all_user_streams)
|
||||||
metrics.register_callback("listeners", count_listeners)
|
metrics.register_callback("listeners", count_listeners)
|
||||||
|
|
||||||
metrics.register_callback(
|
metrics.register_callback(
|
||||||
|
@ -286,10 +293,6 @@ class Notifier(object):
|
||||||
"""Wait until the callback returns a non empty response or the
|
"""Wait until the callback returns a non empty response or the
|
||||||
timeout fires.
|
timeout fires.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
deferred = defer.Deferred()
|
|
||||||
time_now_ms = self.clock.time_msec()
|
|
||||||
|
|
||||||
user = str(user)
|
user = str(user)
|
||||||
user_stream = self.user_to_user_stream.get(user)
|
user_stream = self.user_to_user_stream.get(user)
|
||||||
if user_stream is None:
|
if user_stream is None:
|
||||||
|
@ -302,54 +305,38 @@ class Notifier(object):
|
||||||
rooms=rooms,
|
rooms=rooms,
|
||||||
appservice=appservice,
|
appservice=appservice,
|
||||||
current_token=current_token,
|
current_token=current_token,
|
||||||
time_now_ms=time_now_ms,
|
time_now_ms=self.clock.time_msec(),
|
||||||
)
|
)
|
||||||
self._register_with_keys(user_stream)
|
self._register_with_keys(user_stream)
|
||||||
else:
|
|
||||||
current_token = user_stream.current_token
|
|
||||||
|
|
||||||
result = None
|
result = None
|
||||||
if current_token.is_after(from_token):
|
|
||||||
result = yield callback(from_token, current_token)
|
|
||||||
|
|
||||||
if result:
|
|
||||||
defer.returnValue(result)
|
|
||||||
|
|
||||||
if timeout:
|
if timeout:
|
||||||
timer = [None]
|
listener = None
|
||||||
listeners = []
|
timer = self.clock.call_later(
|
||||||
timed_out = [False]
|
timeout/1000., lambda: listener.cancel()
|
||||||
|
)
|
||||||
|
|
||||||
def notify_listeners():
|
prev_token = from_token
|
||||||
user_stream.listeners.difference_update(listeners)
|
while not result:
|
||||||
for listener in listeners:
|
|
||||||
listener.notify(current_token)
|
|
||||||
del listeners[:]
|
|
||||||
|
|
||||||
def _timeout_listener():
|
|
||||||
timed_out[0] = True
|
|
||||||
timer[0] = None
|
|
||||||
notify_listeners()
|
|
||||||
|
|
||||||
# We create multiple notification listeners so we have to manage
|
|
||||||
# canceling the timeout ourselves.
|
|
||||||
timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
|
|
||||||
|
|
||||||
while not result and not timed_out[0]:
|
|
||||||
deferred = defer.Deferred()
|
|
||||||
notify_listeners()
|
|
||||||
listeners.append(_NotificationListener(deferred))
|
|
||||||
user_stream.listeners.update(listeners)
|
|
||||||
new_token = yield deferred
|
|
||||||
|
|
||||||
result = yield callback(current_token, new_token)
|
|
||||||
current_token = new_token
|
|
||||||
|
|
||||||
if timer[0] is not None:
|
|
||||||
try:
|
try:
|
||||||
self.clock.cancel_call_later(timer[0])
|
# We need to start listening to the streams *before* doing
|
||||||
except:
|
# the callback, as otherwise we may miss something.
|
||||||
logger.exception("Failed to cancel notifer timer")
|
current_token = user_stream.current_token
|
||||||
|
|
||||||
|
result = yield callback(prev_token, current_token)
|
||||||
|
if result:
|
||||||
|
break
|
||||||
|
|
||||||
|
prev_token = current_token
|
||||||
|
listener = user_stream.new_listener(prev_token)
|
||||||
|
yield listener.deferred
|
||||||
|
except defer.CancelledError:
|
||||||
|
break
|
||||||
|
|
||||||
|
self.clock.cancel_call_later(timer, ignore_errs=True)
|
||||||
|
else:
|
||||||
|
current_token = user_stream.current_token
|
||||||
|
result = yield callback(from_token, current_token)
|
||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@ -367,6 +354,9 @@ class Notifier(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_for_updates(before_token, after_token):
|
def check_for_updates(before_token, after_token):
|
||||||
|
if not after_token.is_after(before_token):
|
||||||
|
defer.returnValue(None)
|
||||||
|
|
||||||
events = []
|
events = []
|
||||||
end_token = from_token
|
end_token = from_token
|
||||||
for name, source in self.event_sources.sources.items():
|
for name, source in self.event_sources.sources.items():
|
||||||
|
@ -401,7 +391,7 @@ class Notifier(object):
|
||||||
expired_streams = []
|
expired_streams = []
|
||||||
expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
|
expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
|
||||||
for stream in self.user_to_user_stream.values():
|
for stream in self.user_to_user_stream.values():
|
||||||
if stream.listeners:
|
if stream.count_listeners():
|
||||||
continue
|
continue
|
||||||
if stream.last_notified_ms < expire_before_ts:
|
if stream.last_notified_ms < expire_before_ts:
|
||||||
expired_streams.append(stream)
|
expired_streams.append(stream)
|
||||||
|
|
|
@ -91,8 +91,12 @@ class Clock(object):
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
|
return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
|
||||||
|
|
||||||
def cancel_call_later(self, timer):
|
def cancel_call_later(self, timer, ignore_errs=False):
|
||||||
|
try:
|
||||||
timer.cancel()
|
timer.cancel()
|
||||||
|
except:
|
||||||
|
if not ignore_errs:
|
||||||
|
raise
|
||||||
|
|
||||||
def time_bound_deferred(self, given_deferred, time_out):
|
def time_bound_deferred(self, given_deferred, time_out):
|
||||||
if given_deferred.called:
|
if given_deferred.called:
|
||||||
|
|
|
@ -45,7 +45,7 @@ class ObservableDeferred(object):
|
||||||
def __init__(self, deferred, consumeErrors=False):
|
def __init__(self, deferred, consumeErrors=False):
|
||||||
object.__setattr__(self, "_deferred", deferred)
|
object.__setattr__(self, "_deferred", deferred)
|
||||||
object.__setattr__(self, "_result", None)
|
object.__setattr__(self, "_result", None)
|
||||||
object.__setattr__(self, "_observers", [])
|
object.__setattr__(self, "_observers", set())
|
||||||
|
|
||||||
def callback(r):
|
def callback(r):
|
||||||
self._result = (True, r)
|
self._result = (True, r)
|
||||||
|
@ -74,12 +74,21 @@ class ObservableDeferred(object):
|
||||||
def observe(self):
|
def observe(self):
|
||||||
if not self._result:
|
if not self._result:
|
||||||
d = defer.Deferred()
|
d = defer.Deferred()
|
||||||
self._observers.append(d)
|
|
||||||
|
def remove(r):
|
||||||
|
self._observers.discard(d)
|
||||||
|
return r
|
||||||
|
d.addBoth(remove)
|
||||||
|
|
||||||
|
self._observers.add(d)
|
||||||
return d
|
return d
|
||||||
else:
|
else:
|
||||||
success, res = self._result
|
success, res = self._result
|
||||||
return defer.succeed(res) if success else defer.fail(res)
|
return defer.succeed(res) if success else defer.fail(res)
|
||||||
|
|
||||||
|
def observers(self):
|
||||||
|
return self._observers
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
return getattr(self._deferred, name)
|
return getattr(self._deferred, name)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue