mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-13 12:53:38 +01:00
ObservableDeferred: run observers in order (#11229)
This commit is contained in:
parent
93aa670642
commit
46d0937447
4 changed files with 88 additions and 20 deletions
1
changelog.d/11229.misc
Normal file
1
changelog.d/11229.misc
Normal file
|
@ -0,0 +1 @@
|
|||
`ObservableDeferred`: run registered observers in order.
|
|
@ -22,11 +22,11 @@ from typing import (
|
|||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Generic,
|
||||
Hashable,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
TypeVar,
|
||||
|
@ -76,12 +76,17 @@ class ObservableDeferred(Generic[_T]):
|
|||
def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
|
||||
object.__setattr__(self, "_deferred", deferred)
|
||||
object.__setattr__(self, "_result", None)
|
||||
object.__setattr__(self, "_observers", set())
|
||||
object.__setattr__(self, "_observers", [])
|
||||
|
||||
def callback(r):
|
||||
object.__setattr__(self, "_result", (True, r))
|
||||
while self._observers:
|
||||
observer = self._observers.pop()
|
||||
|
||||
# once we have set _result, no more entries will be added to _observers,
|
||||
# so it's safe to replace it with the empty tuple.
|
||||
observers = self._observers
|
||||
object.__setattr__(self, "_observers", ())
|
||||
|
||||
for observer in observers:
|
||||
try:
|
||||
observer.callback(r)
|
||||
except Exception as e:
|
||||
|
@ -95,12 +100,16 @@ class ObservableDeferred(Generic[_T]):
|
|||
|
||||
def errback(f):
|
||||
object.__setattr__(self, "_result", (False, f))
|
||||
while self._observers:
|
||||
|
||||
# once we have set _result, no more entries will be added to _observers,
|
||||
# so it's safe to replace it with the empty tuple.
|
||||
observers = self._observers
|
||||
object.__setattr__(self, "_observers", ())
|
||||
|
||||
for observer in observers:
|
||||
# This is a little bit of magic to correctly propagate stack
|
||||
# traces when we `await` on one of the observer deferreds.
|
||||
f.value.__failure__ = f
|
||||
|
||||
observer = self._observers.pop()
|
||||
try:
|
||||
observer.errback(f)
|
||||
except Exception as e:
|
||||
|
@ -127,20 +136,13 @@ class ObservableDeferred(Generic[_T]):
|
|||
"""
|
||||
if not self._result:
|
||||
d: "defer.Deferred[_T]" = defer.Deferred()
|
||||
|
||||
def remove(r):
|
||||
self._observers.discard(d)
|
||||
return r
|
||||
|
||||
d.addBoth(remove)
|
||||
|
||||
self._observers.add(d)
|
||||
self._observers.append(d)
|
||||
return d
|
||||
else:
|
||||
success, res = self._result
|
||||
return defer.succeed(res) if success else defer.fail(res)
|
||||
|
||||
def observers(self) -> "List[defer.Deferred[_T]]":
|
||||
def observers(self) -> "Collection[defer.Deferred[_T]]":
|
||||
return self._observers
|
||||
|
||||
def has_called(self) -> bool:
|
||||
|
|
|
@ -47,9 +47,7 @@ class DeferredCacheTestCase(TestCase):
|
|||
self.assertTrue(set_d.called)
|
||||
return r
|
||||
|
||||
# TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8.
|
||||
# maybe we should fix that?
|
||||
# get_d.addCallback(check1)
|
||||
get_d.addCallback(check1)
|
||||
|
||||
# now fire off all the deferreds
|
||||
origin_d.callback(99)
|
||||
|
|
|
@ -21,11 +21,78 @@ from synapse.logging.context import (
|
|||
PreserveLoggingContext,
|
||||
current_context,
|
||||
)
|
||||
from synapse.util.async_helpers import timeout_deferred
|
||||
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
|
||||
|
||||
from tests.unittest import TestCase
|
||||
|
||||
|
||||
class ObservableDeferredTest(TestCase):
|
||||
def test_succeed(self):
|
||||
origin_d = Deferred()
|
||||
observable = ObservableDeferred(origin_d)
|
||||
|
||||
observer1 = observable.observe()
|
||||
observer2 = observable.observe()
|
||||
|
||||
self.assertFalse(observer1.called)
|
||||
self.assertFalse(observer2.called)
|
||||
|
||||
# check the first observer is called first
|
||||
def check_called_first(res):
|
||||
self.assertFalse(observer2.called)
|
||||
return res
|
||||
|
||||
observer1.addBoth(check_called_first)
|
||||
|
||||
# store the results
|
||||
results = [None, None]
|
||||
|
||||
def check_val(res, idx):
|
||||
results[idx] = res
|
||||
return res
|
||||
|
||||
observer1.addCallback(check_val, 0)
|
||||
observer2.addCallback(check_val, 1)
|
||||
|
||||
origin_d.callback(123)
|
||||
self.assertEqual(results[0], 123, "observer 1 callback result")
|
||||
self.assertEqual(results[1], 123, "observer 2 callback result")
|
||||
|
||||
def test_failure(self):
|
||||
origin_d = Deferred()
|
||||
observable = ObservableDeferred(origin_d, consumeErrors=True)
|
||||
|
||||
observer1 = observable.observe()
|
||||
observer2 = observable.observe()
|
||||
|
||||
self.assertFalse(observer1.called)
|
||||
self.assertFalse(observer2.called)
|
||||
|
||||
# check the first observer is called first
|
||||
def check_called_first(res):
|
||||
self.assertFalse(observer2.called)
|
||||
return res
|
||||
|
||||
observer1.addBoth(check_called_first)
|
||||
|
||||
# store the results
|
||||
results = [None, None]
|
||||
|
||||
def check_val(res, idx):
|
||||
results[idx] = res
|
||||
return None
|
||||
|
||||
observer1.addErrback(check_val, 0)
|
||||
observer2.addErrback(check_val, 1)
|
||||
|
||||
try:
|
||||
raise Exception("gah!")
|
||||
except Exception as e:
|
||||
origin_d.errback(e)
|
||||
self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
|
||||
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
|
||||
|
||||
|
||||
class TimeoutDeferredTest(TestCase):
|
||||
def setUp(self):
|
||||
self.clock = Clock()
|
Loading…
Reference in a new issue