forked from MirrorHub/synapse
Another batch of type annotations (#12726)
This commit is contained in:
parent
39bed28b28
commit
aec69d2481
11 changed files with 144 additions and 79 deletions
1
changelog.d/12726.misc
Normal file
1
changelog.d/12726.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add type annotations to increase the number of modules passing `disallow-untyped-defs`.
|
21
mypy.ini
21
mypy.ini
|
@ -128,15 +128,30 @@ disallow_untyped_defs = True
|
||||||
[mypy-synapse.http.federation.*]
|
[mypy-synapse.http.federation.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.http.connectproxyclient]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.http.proxyagent]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.http.request_metrics]
|
[mypy-synapse.http.request_metrics]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.http.server]
|
[mypy-synapse.http.server]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.logging._remote]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.logging.context]
|
[mypy-synapse.logging.context]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.logging.formatter]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.logging.handlers]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.metrics.*]
|
[mypy-synapse.metrics.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
@ -166,6 +181,9 @@ disallow_untyped_defs = True
|
||||||
[mypy-synapse.state.*]
|
[mypy-synapse.state.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.storage.databases.background_updates]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.storage.databases.main.account_data]
|
[mypy-synapse.storage.databases.main.account_data]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
@ -232,6 +250,9 @@ disallow_untyped_defs = True
|
||||||
[mypy-synapse.streams.*]
|
[mypy-synapse.streams.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.types]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.util.*]
|
[mypy-synapse.util.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
@ -1105,22 +1105,19 @@ class E2eKeysHandler:
|
||||||
# can request over federation
|
# can request over federation
|
||||||
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
|
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
|
||||||
|
|
||||||
(
|
cross_signing_keys = await self._retrieve_cross_signing_keys_for_remote_user(
|
||||||
key,
|
user, key_type
|
||||||
key_id,
|
)
|
||||||
verify_key,
|
if cross_signing_keys is None:
|
||||||
) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
|
|
||||||
|
|
||||||
if key is None:
|
|
||||||
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
|
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
|
||||||
|
|
||||||
return key, key_id, verify_key
|
return cross_signing_keys
|
||||||
|
|
||||||
async def _retrieve_cross_signing_keys_for_remote_user(
|
async def _retrieve_cross_signing_keys_for_remote_user(
|
||||||
self,
|
self,
|
||||||
user: UserID,
|
user: UserID,
|
||||||
desired_key_type: str,
|
desired_key_type: str,
|
||||||
) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
|
) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]:
|
||||||
"""Queries cross-signing keys for a remote user and saves them to the database
|
"""Queries cross-signing keys for a remote user and saves them to the database
|
||||||
|
|
||||||
Only the key specified by `key_type` will be returned, while all retrieved keys
|
Only the key specified by `key_type` will be returned, while all retrieved keys
|
||||||
|
@ -1146,12 +1143,10 @@ class E2eKeysHandler:
|
||||||
type(e),
|
type(e),
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
return None, None, None
|
return None
|
||||||
|
|
||||||
# Process each of the retrieved cross-signing keys
|
# Process each of the retrieved cross-signing keys
|
||||||
desired_key = None
|
desired_key_data = None
|
||||||
desired_key_id = None
|
|
||||||
desired_verify_key = None
|
|
||||||
retrieved_device_ids = []
|
retrieved_device_ids = []
|
||||||
for key_type in ["master", "self_signing"]:
|
for key_type in ["master", "self_signing"]:
|
||||||
key_content = remote_result.get(key_type + "_key")
|
key_content = remote_result.get(key_type + "_key")
|
||||||
|
@ -1196,9 +1191,7 @@ class E2eKeysHandler:
|
||||||
|
|
||||||
# If this is the desired key type, save it and its ID/VerifyKey
|
# If this is the desired key type, save it and its ID/VerifyKey
|
||||||
if key_type == desired_key_type:
|
if key_type == desired_key_type:
|
||||||
desired_key = key_content
|
desired_key_data = key_content, key_id, verify_key
|
||||||
desired_verify_key = verify_key
|
|
||||||
desired_key_id = key_id
|
|
||||||
|
|
||||||
# At the same time, store this key in the db for subsequent queries
|
# At the same time, store this key in the db for subsequent queries
|
||||||
await self.store.set_e2e_cross_signing_key(
|
await self.store.set_e2e_cross_signing_key(
|
||||||
|
@ -1212,7 +1205,7 @@ class E2eKeysHandler:
|
||||||
user.to_string(), retrieved_device_ids
|
user.to_string(), retrieved_device_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
return desired_key, desired_key_id, desired_verify_key
|
return desired_key_data
|
||||||
|
|
||||||
|
|
||||||
def _check_cross_signing_key(
|
def _check_cross_signing_key(
|
||||||
|
|
|
@ -14,15 +14,22 @@
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
|
|
||||||
from twisted.internet import defer, protocol
|
from twisted.internet import defer, protocol
|
||||||
from twisted.internet.error import ConnectError
|
from twisted.internet.error import ConnectError
|
||||||
from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
|
from twisted.internet.interfaces import (
|
||||||
|
IAddress,
|
||||||
|
IConnector,
|
||||||
|
IProtocol,
|
||||||
|
IReactorCore,
|
||||||
|
IStreamClientEndpoint,
|
||||||
|
)
|
||||||
from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
|
from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
|
||||||
|
from twisted.python.failure import Failure
|
||||||
from twisted.web import http
|
from twisted.web import http
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -81,14 +88,14 @@ class HTTPConnectProxyEndpoint:
|
||||||
self._port = port
|
self._port = port
|
||||||
self._proxy_creds = proxy_creds
|
self._proxy_creds = proxy_creds
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
|
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
|
||||||
|
|
||||||
# Mypy encounters a false positive here: it complains that ClientFactory
|
# Mypy encounters a false positive here: it complains that ClientFactory
|
||||||
# is incompatible with IProtocolFactory. But ClientFactory inherits from
|
# is incompatible with IProtocolFactory. But ClientFactory inherits from
|
||||||
# Factory, which implements IProtocolFactory. So I think this is a bug
|
# Factory, which implements IProtocolFactory. So I think this is a bug
|
||||||
# in mypy-zope.
|
# in mypy-zope.
|
||||||
def connect(self, protocolFactory: ClientFactory): # type: ignore[override]
|
def connect(self, protocolFactory: ClientFactory) -> "defer.Deferred[IProtocol]": # type: ignore[override]
|
||||||
f = HTTPProxiedClientFactory(
|
f = HTTPProxiedClientFactory(
|
||||||
self._host, self._port, protocolFactory, self._proxy_creds
|
self._host, self._port, protocolFactory, self._proxy_creds
|
||||||
)
|
)
|
||||||
|
@ -125,10 +132,10 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
|
||||||
self.proxy_creds = proxy_creds
|
self.proxy_creds = proxy_creds
|
||||||
self.on_connection: "defer.Deferred[None]" = defer.Deferred()
|
self.on_connection: "defer.Deferred[None]" = defer.Deferred()
|
||||||
|
|
||||||
def startedConnecting(self, connector):
|
def startedConnecting(self, connector: IConnector) -> None:
|
||||||
return self.wrapped_factory.startedConnecting(connector)
|
return self.wrapped_factory.startedConnecting(connector)
|
||||||
|
|
||||||
def buildProtocol(self, addr):
|
def buildProtocol(self, addr: IAddress) -> "HTTPConnectProtocol":
|
||||||
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
|
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
|
||||||
if wrapped_protocol is None:
|
if wrapped_protocol is None:
|
||||||
raise TypeError("buildProtocol produced None instead of a Protocol")
|
raise TypeError("buildProtocol produced None instead of a Protocol")
|
||||||
|
@ -141,13 +148,13 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
|
||||||
self.proxy_creds,
|
self.proxy_creds,
|
||||||
)
|
)
|
||||||
|
|
||||||
def clientConnectionFailed(self, connector, reason):
|
def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
|
||||||
logger.debug("Connection to proxy failed: %s", reason)
|
logger.debug("Connection to proxy failed: %s", reason)
|
||||||
if not self.on_connection.called:
|
if not self.on_connection.called:
|
||||||
self.on_connection.errback(reason)
|
self.on_connection.errback(reason)
|
||||||
return self.wrapped_factory.clientConnectionFailed(connector, reason)
|
return self.wrapped_factory.clientConnectionFailed(connector, reason)
|
||||||
|
|
||||||
def clientConnectionLost(self, connector, reason):
|
def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
|
||||||
logger.debug("Connection to proxy lost: %s", reason)
|
logger.debug("Connection to proxy lost: %s", reason)
|
||||||
if not self.on_connection.called:
|
if not self.on_connection.called:
|
||||||
self.on_connection.errback(reason)
|
self.on_connection.errback(reason)
|
||||||
|
@ -191,10 +198,10 @@ class HTTPConnectProtocol(protocol.Protocol):
|
||||||
)
|
)
|
||||||
self.http_setup_client.on_connected.addCallback(self.proxyConnected)
|
self.http_setup_client.on_connected.addCallback(self.proxyConnected)
|
||||||
|
|
||||||
def connectionMade(self):
|
def connectionMade(self) -> None:
|
||||||
self.http_setup_client.makeConnection(self.transport)
|
self.http_setup_client.makeConnection(self.transport)
|
||||||
|
|
||||||
def connectionLost(self, reason=connectionDone):
|
def connectionLost(self, reason: Failure = connectionDone) -> None:
|
||||||
if self.wrapped_protocol.connected:
|
if self.wrapped_protocol.connected:
|
||||||
self.wrapped_protocol.connectionLost(reason)
|
self.wrapped_protocol.connectionLost(reason)
|
||||||
|
|
||||||
|
@ -203,7 +210,7 @@ class HTTPConnectProtocol(protocol.Protocol):
|
||||||
if not self.connected_deferred.called:
|
if not self.connected_deferred.called:
|
||||||
self.connected_deferred.errback(reason)
|
self.connected_deferred.errback(reason)
|
||||||
|
|
||||||
def proxyConnected(self, _):
|
def proxyConnected(self, _: Union[None, "defer.Deferred[None]"]) -> None:
|
||||||
self.wrapped_protocol.makeConnection(self.transport)
|
self.wrapped_protocol.makeConnection(self.transport)
|
||||||
|
|
||||||
self.connected_deferred.callback(self.wrapped_protocol)
|
self.connected_deferred.callback(self.wrapped_protocol)
|
||||||
|
@ -213,7 +220,7 @@ class HTTPConnectProtocol(protocol.Protocol):
|
||||||
if buf:
|
if buf:
|
||||||
self.wrapped_protocol.dataReceived(buf)
|
self.wrapped_protocol.dataReceived(buf)
|
||||||
|
|
||||||
def dataReceived(self, data: bytes):
|
def dataReceived(self, data: bytes) -> None:
|
||||||
# if we've set up the HTTP protocol, we can send the data there
|
# if we've set up the HTTP protocol, we can send the data there
|
||||||
if self.wrapped_protocol.connected:
|
if self.wrapped_protocol.connected:
|
||||||
return self.wrapped_protocol.dataReceived(data)
|
return self.wrapped_protocol.dataReceived(data)
|
||||||
|
@ -243,7 +250,7 @@ class HTTPConnectSetupClient(http.HTTPClient):
|
||||||
self.proxy_creds = proxy_creds
|
self.proxy_creds = proxy_creds
|
||||||
self.on_connected: "defer.Deferred[None]" = defer.Deferred()
|
self.on_connected: "defer.Deferred[None]" = defer.Deferred()
|
||||||
|
|
||||||
def connectionMade(self):
|
def connectionMade(self) -> None:
|
||||||
logger.debug("Connected to proxy, sending CONNECT")
|
logger.debug("Connected to proxy, sending CONNECT")
|
||||||
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
|
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
|
||||||
|
|
||||||
|
@ -257,14 +264,14 @@ class HTTPConnectSetupClient(http.HTTPClient):
|
||||||
|
|
||||||
self.endHeaders()
|
self.endHeaders()
|
||||||
|
|
||||||
def handleStatus(self, version: bytes, status: bytes, message: bytes):
|
def handleStatus(self, version: bytes, status: bytes, message: bytes) -> None:
|
||||||
logger.debug("Got Status: %s %s %s", status, message, version)
|
logger.debug("Got Status: %s %s %s", status, message, version)
|
||||||
if status != b"200":
|
if status != b"200":
|
||||||
raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}")
|
raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}")
|
||||||
|
|
||||||
def handleEndHeaders(self):
|
def handleEndHeaders(self) -> None:
|
||||||
logger.debug("End Headers")
|
logger.debug("End Headers")
|
||||||
self.on_connected.callback(None)
|
self.on_connected.callback(None)
|
||||||
|
|
||||||
def handleResponse(self, body):
|
def handleResponse(self, body: bytes) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -245,7 +245,7 @@ def http_proxy_endpoint(
|
||||||
proxy: Optional[bytes],
|
proxy: Optional[bytes],
|
||||||
reactor: IReactorCore,
|
reactor: IReactorCore,
|
||||||
tls_options_factory: Optional[IPolicyForHTTPS],
|
tls_options_factory: Optional[IPolicyForHTTPS],
|
||||||
**kwargs,
|
**kwargs: object,
|
||||||
) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
|
) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
|
||||||
"""Parses an http proxy setting and returns an endpoint for the proxy
|
"""Parses an http proxy setting and returns an endpoint for the proxy
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,11 @@ from twisted.internet.endpoints import (
|
||||||
TCP4ClientEndpoint,
|
TCP4ClientEndpoint,
|
||||||
TCP6ClientEndpoint,
|
TCP6ClientEndpoint,
|
||||||
)
|
)
|
||||||
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
|
from twisted.internet.interfaces import (
|
||||||
|
IPushProducer,
|
||||||
|
IReactorTCP,
|
||||||
|
IStreamClientEndpoint,
|
||||||
|
)
|
||||||
from twisted.internet.protocol import Factory, Protocol
|
from twisted.internet.protocol import Factory, Protocol
|
||||||
from twisted.internet.tcp import Connection
|
from twisted.internet.tcp import Connection
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
|
@ -59,14 +63,14 @@ class LogProducer:
|
||||||
_buffer: Deque[logging.LogRecord]
|
_buffer: Deque[logging.LogRecord]
|
||||||
_paused: bool = attr.ib(default=False, init=False)
|
_paused: bool = attr.ib(default=False, init=False)
|
||||||
|
|
||||||
def pauseProducing(self):
|
def pauseProducing(self) -> None:
|
||||||
self._paused = True
|
self._paused = True
|
||||||
|
|
||||||
def stopProducing(self):
|
def stopProducing(self) -> None:
|
||||||
self._paused = True
|
self._paused = True
|
||||||
self._buffer = deque()
|
self._buffer = deque()
|
||||||
|
|
||||||
def resumeProducing(self):
|
def resumeProducing(self) -> None:
|
||||||
# If we're already producing, nothing to do.
|
# If we're already producing, nothing to do.
|
||||||
self._paused = False
|
self._paused = False
|
||||||
|
|
||||||
|
@ -102,8 +106,8 @@ class RemoteHandler(logging.Handler):
|
||||||
host: str,
|
host: str,
|
||||||
port: int,
|
port: int,
|
||||||
maximum_buffer: int = 1000,
|
maximum_buffer: int = 1000,
|
||||||
level=logging.NOTSET,
|
level: int = logging.NOTSET,
|
||||||
_reactor=None,
|
_reactor: Optional[IReactorTCP] = None,
|
||||||
):
|
):
|
||||||
super().__init__(level=level)
|
super().__init__(level=level)
|
||||||
self.host = host
|
self.host = host
|
||||||
|
@ -118,7 +122,7 @@ class RemoteHandler(logging.Handler):
|
||||||
if _reactor is None:
|
if _reactor is None:
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
|
|
||||||
_reactor = reactor
|
_reactor = reactor # type: ignore[assignment]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ip = ip_address(self.host)
|
ip = ip_address(self.host)
|
||||||
|
@ -139,7 +143,7 @@ class RemoteHandler(logging.Handler):
|
||||||
self._stopping = False
|
self._stopping = False
|
||||||
self._connect()
|
self._connect()
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
self._stopping = True
|
self._stopping = True
|
||||||
self._service.stopService()
|
self._service.stopService()
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Optional, Tuple, Type
|
||||||
|
|
||||||
|
|
||||||
class LogFormatter(logging.Formatter):
|
class LogFormatter(logging.Formatter):
|
||||||
|
@ -28,10 +30,14 @@ class LogFormatter(logging.Formatter):
|
||||||
where it was caught are logged).
|
where it was caught are logged).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def formatException(
|
||||||
super().__init__(*args, **kwargs)
|
self,
|
||||||
|
ei: Tuple[
|
||||||
def formatException(self, ei):
|
Optional[Type[BaseException]],
|
||||||
|
Optional[BaseException],
|
||||||
|
Optional[TracebackType],
|
||||||
|
],
|
||||||
|
) -> str:
|
||||||
sio = StringIO()
|
sio = StringIO()
|
||||||
(typ, val, tb) = ei
|
(typ, val, tb) = ei
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
|
||||||
)
|
)
|
||||||
self._flushing_thread.start()
|
self._flushing_thread.start()
|
||||||
|
|
||||||
def on_reactor_running():
|
def on_reactor_running() -> None:
|
||||||
self._reactor_started = True
|
self._reactor_started = True
|
||||||
|
|
||||||
reactor_to_use: IReactorCore
|
reactor_to_use: IReactorCore
|
||||||
|
@ -74,7 +74,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _flush_periodically(self):
|
def _flush_periodically(self) -> None:
|
||||||
"""
|
"""
|
||||||
Whilst this handler is active, flush the handler periodically.
|
Whilst this handler is active, flush the handler periodically.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -13,6 +13,8 @@
|
||||||
# limitations under the License.import logging
|
# limitations under the License.import logging
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Optional, Type
|
||||||
|
|
||||||
from opentracing import Scope, ScopeManager
|
from opentracing import Scope, ScopeManager
|
||||||
|
|
||||||
|
@ -107,19 +109,26 @@ class _LogContextScope(Scope):
|
||||||
and - if enter_logcontext was set - the logcontext is finished too.
|
and - if enter_logcontext was set - the logcontext is finished too.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, manager, span, logcontext, enter_logcontext, finish_on_close):
|
def __init__(
|
||||||
|
self,
|
||||||
|
manager: LogContextScopeManager,
|
||||||
|
span,
|
||||||
|
logcontext,
|
||||||
|
enter_logcontext: bool,
|
||||||
|
finish_on_close: bool,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
manager (LogContextScopeManager):
|
manager:
|
||||||
the manager that is responsible for this scope.
|
the manager that is responsible for this scope.
|
||||||
span (Span):
|
span (Span):
|
||||||
the opentracing span which this scope represents the local
|
the opentracing span which this scope represents the local
|
||||||
lifetime for.
|
lifetime for.
|
||||||
logcontext (LogContext):
|
logcontext (LogContext):
|
||||||
the logcontext to which this scope is attached.
|
the logcontext to which this scope is attached.
|
||||||
enter_logcontext (Boolean):
|
enter_logcontext:
|
||||||
if True the logcontext will be exited when the scope is finished
|
if True the logcontext will be exited when the scope is finished
|
||||||
finish_on_close (Boolean):
|
finish_on_close:
|
||||||
if True finish the span when the scope is closed
|
if True finish the span when the scope is closed
|
||||||
"""
|
"""
|
||||||
super().__init__(manager, span)
|
super().__init__(manager, span)
|
||||||
|
@ -127,16 +136,21 @@ class _LogContextScope(Scope):
|
||||||
self._finish_on_close = finish_on_close
|
self._finish_on_close = finish_on_close
|
||||||
self._enter_logcontext = enter_logcontext
|
self._enter_logcontext = enter_logcontext
|
||||||
|
|
||||||
def __exit__(self, exc_type, value, traceback):
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[Type[BaseException]],
|
||||||
|
value: Optional[BaseException],
|
||||||
|
traceback: Optional[TracebackType],
|
||||||
|
) -> None:
|
||||||
if exc_type == twisted.internet.defer._DefGen_Return:
|
if exc_type == twisted.internet.defer._DefGen_Return:
|
||||||
# filter out defer.returnValue() calls
|
# filter out defer.returnValue() calls
|
||||||
exc_type = value = traceback = None
|
exc_type = value = traceback = None
|
||||||
super().__exit__(exc_type, value, traceback)
|
super().__exit__(exc_type, value, traceback)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return f"Scope<{self.span}>"
|
return f"Scope<{self.span}>"
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
active_scope = self.manager.active
|
active_scope = self.manager.active
|
||||||
if active_scope is not self:
|
if active_scope is not self:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
|
@ -12,20 +12,24 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
from types import TracebackType
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
AsyncContextManager,
|
AsyncContextManager,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
Type,
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage.types import Connection
|
from synapse.storage.types import Connection, Cursor
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import Clock, json_encoder
|
from synapse.util import Clock, json_encoder
|
||||||
|
|
||||||
|
@ -74,7 +78,12 @@ class _BackgroundUpdateContextManager:
|
||||||
|
|
||||||
return self._update_duration_ms
|
return self._update_duration_ms
|
||||||
|
|
||||||
async def __aexit__(self, *exc) -> None:
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[Type[BaseException]],
|
||||||
|
exc: Optional[BaseException],
|
||||||
|
tb: Optional[TracebackType],
|
||||||
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -352,7 +361,7 @@ class BackgroundUpdater:
|
||||||
True if we have finished running all the background updates, otherwise False
|
True if we have finished running all the background updates, otherwise False
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_background_updates_txn(txn):
|
def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
SELECT update_name, depends_on FROM background_updates
|
SELECT update_name, depends_on FROM background_updates
|
||||||
|
@ -469,7 +478,7 @@ class BackgroundUpdater:
|
||||||
self,
|
self,
|
||||||
update_name: str,
|
update_name: str,
|
||||||
update_handler: Callable[[JsonDict, int], Awaitable[int]],
|
update_handler: Callable[[JsonDict, int], Awaitable[int]],
|
||||||
):
|
) -> None:
|
||||||
"""Register a handler for doing a background update.
|
"""Register a handler for doing a background update.
|
||||||
|
|
||||||
The handler should take two arguments:
|
The handler should take two arguments:
|
||||||
|
@ -603,7 +612,7 @@ class BackgroundUpdater:
|
||||||
else:
|
else:
|
||||||
runner = create_index_sqlite
|
runner = create_index_sqlite
|
||||||
|
|
||||||
async def updater(progress, batch_size):
|
async def updater(progress: JsonDict, batch_size: int) -> int:
|
||||||
if runner is not None:
|
if runner is not None:
|
||||||
logger.info("Adding index %s to %s", index_name, table)
|
logger.info("Adding index %s to %s", index_name, table)
|
||||||
await self.db_pool.runWithConnection(runner)
|
await self.db_pool.runWithConnection(runner)
|
||||||
|
|
|
@ -24,6 +24,7 @@ from typing import (
|
||||||
Mapping,
|
Mapping,
|
||||||
Match,
|
Match,
|
||||||
MutableMapping,
|
MutableMapping,
|
||||||
|
NoReturn,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
@ -35,6 +36,7 @@ from typing import (
|
||||||
import attr
|
import attr
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
|
from signedjson.types import VerifyKey
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
from zope.interface import Interface
|
from zope.interface import Interface
|
||||||
|
@ -55,6 +57,7 @@ from synapse.util.stringutils import parse_and_validate_server_name
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.appservice.api import ApplicationService
|
from synapse.appservice.api import ApplicationService
|
||||||
from synapse.storage.databases.main import DataStore, PurgeEventsStore
|
from synapse.storage.databases.main import DataStore, PurgeEventsStore
|
||||||
|
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
|
||||||
|
|
||||||
# Define a state map type from type/state_key to T (usually an event ID or
|
# Define a state map type from type/state_key to T (usually an event ID or
|
||||||
# event)
|
# event)
|
||||||
|
@ -114,7 +117,7 @@ class Requester:
|
||||||
app_service: Optional["ApplicationService"]
|
app_service: Optional["ApplicationService"]
|
||||||
authenticated_entity: str
|
authenticated_entity: str
|
||||||
|
|
||||||
def serialize(self):
|
def serialize(self) -> Dict[str, Any]:
|
||||||
"""Converts self to a type that can be serialized as JSON, and then
|
"""Converts self to a type that can be serialized as JSON, and then
|
||||||
deserialized by `deserialize`
|
deserialized by `deserialize`
|
||||||
|
|
||||||
|
@ -132,7 +135,9 @@ class Requester:
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def deserialize(store, input):
|
def deserialize(
|
||||||
|
store: "ApplicationServiceWorkerStore", input: Dict[str, Any]
|
||||||
|
) -> "Requester":
|
||||||
"""Converts a dict that was produced by `serialize` back into a
|
"""Converts a dict that was produced by `serialize` back into a
|
||||||
Requester.
|
Requester.
|
||||||
|
|
||||||
|
@ -236,10 +241,10 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
|
||||||
domain: str
|
domain: str
|
||||||
|
|
||||||
# Because this is a frozen class, it is deeply immutable.
|
# Because this is a frozen class, it is deeply immutable.
|
||||||
def __copy__(self):
|
def __copy__(self: DS) -> DS:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __deepcopy__(self, memo):
|
def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -729,12 +734,14 @@ class StreamToken:
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def room_stream_id(self):
|
def room_stream_id(self) -> int:
|
||||||
return self.room_key.stream
|
return self.room_key.stream
|
||||||
|
|
||||||
def copy_and_advance(self, key, new_value) -> "StreamToken":
|
def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken":
|
||||||
"""Advance the given key in the token to a new value if and only if the
|
"""Advance the given key in the token to a new value if and only if the
|
||||||
new value is after the old value.
|
new value is after the old value.
|
||||||
|
|
||||||
|
:raises TypeError: if `key` is not the one of the keys tracked by a StreamToken.
|
||||||
"""
|
"""
|
||||||
if key == "room_key":
|
if key == "room_key":
|
||||||
new_token = self.copy_and_replace(
|
new_token = self.copy_and_replace(
|
||||||
|
@ -751,7 +758,7 @@ class StreamToken:
|
||||||
else:
|
else:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def copy_and_replace(self, key, new_value) -> "StreamToken":
|
def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken":
|
||||||
return attr.evolve(self, **{key: new_value})
|
return attr.evolve(self, **{key: new_value})
|
||||||
|
|
||||||
|
|
||||||
|
@ -793,14 +800,14 @@ class ThirdPartyInstanceID:
|
||||||
# Deny iteration because it will bite you if you try to create a singleton
|
# Deny iteration because it will bite you if you try to create a singleton
|
||||||
# set by:
|
# set by:
|
||||||
# users = set(user)
|
# users = set(user)
|
||||||
def __iter__(self):
|
def __iter__(self) -> NoReturn:
|
||||||
raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
|
raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
|
||||||
|
|
||||||
# Because this class is a frozen class, it is deeply immutable.
|
# Because this class is a frozen class, it is deeply immutable.
|
||||||
def __copy__(self):
|
def __copy__(self) -> "ThirdPartyInstanceID":
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __deepcopy__(self, memo):
|
def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID":
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -852,25 +859,28 @@ class DeviceListUpdates:
|
||||||
return bool(self.changed or self.left)
|
return bool(self.changed or self.left)
|
||||||
|
|
||||||
|
|
||||||
def get_verify_key_from_cross_signing_key(key_info):
|
def get_verify_key_from_cross_signing_key(
|
||||||
|
key_info: Mapping[str, Any]
|
||||||
|
) -> Tuple[str, VerifyKey]:
|
||||||
"""Get the key ID and signedjson verify key from a cross-signing key dict
|
"""Get the key ID and signedjson verify key from a cross-signing key dict
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key_info (dict): a cross-signing key dict, which must have a "keys"
|
key_info: a cross-signing key dict, which must have a "keys"
|
||||||
property that has exactly one item in it
|
property that has exactly one item in it
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(str, VerifyKey): the key ID and verify key for the cross-signing key
|
the key ID and verify key for the cross-signing key
|
||||||
"""
|
"""
|
||||||
# make sure that exactly one key is provided
|
# make sure that a `keys` field is provided
|
||||||
if "keys" not in key_info:
|
if "keys" not in key_info:
|
||||||
raise ValueError("Invalid key")
|
raise ValueError("Invalid key")
|
||||||
keys = key_info["keys"]
|
keys = key_info["keys"]
|
||||||
if len(keys) != 1:
|
# and that it contains exactly one key
|
||||||
raise ValueError("Invalid key")
|
if len(keys) == 1:
|
||||||
# and return that one key
|
key_id, key_data = next(iter(keys.items()))
|
||||||
for key_id, key_data in keys.items():
|
|
||||||
return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))
|
return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid key")
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
||||||
|
|
Loading…
Reference in a new issue