forked from MirrorHub/synapse
Generics for ObservableDeferred
(#10491)
Now that `Deferred` is a generic class, let's update `ObeservableDeferred` to follow suit.
This commit is contained in:
parent
d0b294ad97
commit
858363d0b7
4 changed files with 15 additions and 9 deletions
1
changelog.d/10491.misc
Normal file
1
changelog.d/10491.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Improve type annotations for `ObservableDeferred`.
|
|
@ -111,8 +111,9 @@ class _NotifierUserStream:
|
||||||
self.last_notified_token = current_token
|
self.last_notified_token = current_token
|
||||||
self.last_notified_ms = time_now_ms
|
self.last_notified_ms = time_now_ms
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
self.notify_deferred: ObservableDeferred[StreamToken] = ObservableDeferred(
|
||||||
self.notify_deferred = ObservableDeferred(defer.Deferred())
|
defer.Deferred()
|
||||||
|
)
|
||||||
|
|
||||||
def notify(
|
def notify(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -170,7 +170,9 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
|
||||||
end_item = queue[-1]
|
end_item = queue[-1]
|
||||||
else:
|
else:
|
||||||
# need to make a new queue item
|
# need to make a new queue item
|
||||||
deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
|
deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
|
||||||
|
defer.Deferred(), consumeErrors=True
|
||||||
|
)
|
||||||
|
|
||||||
end_item = _EventPersistQueueItem(
|
end_item = _EventPersistQueueItem(
|
||||||
events_and_contexts=[],
|
events_and_contexts=[],
|
||||||
|
|
|
@ -23,6 +23,7 @@ from typing import (
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
|
Generic,
|
||||||
Hashable,
|
Hashable,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
@ -39,6 +40,7 @@ from twisted.internet import defer
|
||||||
from twisted.internet.defer import CancelledError
|
from twisted.internet.defer import CancelledError
|
||||||
from twisted.internet.interfaces import IReactorTime
|
from twisted.internet.interfaces import IReactorTime
|
||||||
from twisted.python import failure
|
from twisted.python import failure
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
from synapse.logging.context import (
|
from synapse.logging.context import (
|
||||||
PreserveLoggingContext,
|
PreserveLoggingContext,
|
||||||
|
@ -52,7 +54,7 @@ logger = logging.getLogger(__name__)
|
||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
class ObservableDeferred:
|
class ObservableDeferred(Generic[_T]):
|
||||||
"""Wraps a deferred object so that we can add observer deferreds. These
|
"""Wraps a deferred object so that we can add observer deferreds. These
|
||||||
observer deferreds do not affect the callback chain of the original
|
observer deferreds do not affect the callback chain of the original
|
||||||
deferred.
|
deferred.
|
||||||
|
@ -70,7 +72,7 @@ class ObservableDeferred:
|
||||||
|
|
||||||
__slots__ = ["_deferred", "_observers", "_result"]
|
__slots__ = ["_deferred", "_observers", "_result"]
|
||||||
|
|
||||||
def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
|
def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
|
||||||
object.__setattr__(self, "_deferred", deferred)
|
object.__setattr__(self, "_deferred", deferred)
|
||||||
object.__setattr__(self, "_result", None)
|
object.__setattr__(self, "_result", None)
|
||||||
object.__setattr__(self, "_observers", set())
|
object.__setattr__(self, "_observers", set())
|
||||||
|
@ -115,7 +117,7 @@ class ObservableDeferred:
|
||||||
|
|
||||||
deferred.addCallbacks(callback, errback)
|
deferred.addCallbacks(callback, errback)
|
||||||
|
|
||||||
def observe(self) -> defer.Deferred:
|
def observe(self) -> "defer.Deferred[_T]":
|
||||||
"""Observe the underlying deferred.
|
"""Observe the underlying deferred.
|
||||||
|
|
||||||
This returns a brand new deferred that is resolved when the underlying
|
This returns a brand new deferred that is resolved when the underlying
|
||||||
|
@ -123,7 +125,7 @@ class ObservableDeferred:
|
||||||
effect the underlying deferred.
|
effect the underlying deferred.
|
||||||
"""
|
"""
|
||||||
if not self._result:
|
if not self._result:
|
||||||
d: "defer.Deferred[Any]" = defer.Deferred()
|
d: "defer.Deferred[_T]" = defer.Deferred()
|
||||||
|
|
||||||
def remove(r):
|
def remove(r):
|
||||||
self._observers.discard(d)
|
self._observers.discard(d)
|
||||||
|
@ -137,7 +139,7 @@ class ObservableDeferred:
|
||||||
success, res = self._result
|
success, res = self._result
|
||||||
return defer.succeed(res) if success else defer.fail(res)
|
return defer.succeed(res) if success else defer.fail(res)
|
||||||
|
|
||||||
def observers(self) -> List[defer.Deferred]:
|
def observers(self) -> "List[defer.Deferred[_T]]":
|
||||||
return self._observers
|
return self._observers
|
||||||
|
|
||||||
def has_called(self) -> bool:
|
def has_called(self) -> bool:
|
||||||
|
@ -146,7 +148,7 @@ class ObservableDeferred:
|
||||||
def has_succeeded(self) -> bool:
|
def has_succeeded(self) -> bool:
|
||||||
return self._result is not None and self._result[0] is True
|
return self._result is not None and self._result[0] is True
|
||||||
|
|
||||||
def get_result(self) -> Any:
|
def get_result(self) -> Union[_T, Failure]:
|
||||||
return self._result[1]
|
return self._result[1]
|
||||||
|
|
||||||
def __getattr__(self, name: str) -> Any:
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
|
Loading…
Reference in a new issue