Another batch of type annotations (#12726)

This commit is contained in:
David Robertson 2022-05-13 12:35:31 +01:00 committed by GitHub
parent 39bed28b28
commit aec69d2481
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 144 additions and 79 deletions

1
changelog.d/12726.misc Normal file
View file

@@ -0,0 +1 @@
Add type annotations to increase the number of modules passing `disallow-untyped-defs`.

View file

@@ -128,15 +128,30 @@ disallow_untyped_defs = True
[mypy-synapse.http.federation.*] [mypy-synapse.http.federation.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.http.connectproxyclient]
disallow_untyped_defs = True
[mypy-synapse.http.proxyagent]
disallow_untyped_defs = True
[mypy-synapse.http.request_metrics] [mypy-synapse.http.request_metrics]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.http.server] [mypy-synapse.http.server]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.logging._remote]
disallow_untyped_defs = True
[mypy-synapse.logging.context] [mypy-synapse.logging.context]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.logging.formatter]
disallow_untyped_defs = True
[mypy-synapse.logging.handlers]
disallow_untyped_defs = True
[mypy-synapse.metrics.*] [mypy-synapse.metrics.*]
disallow_untyped_defs = True disallow_untyped_defs = True
@@ -166,6 +181,9 @@ disallow_untyped_defs = True
[mypy-synapse.state.*] [mypy-synapse.state.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.storage.databases.background_updates]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.account_data] [mypy-synapse.storage.databases.main.account_data]
disallow_untyped_defs = True disallow_untyped_defs = True
@@ -232,6 +250,9 @@ disallow_untyped_defs = True
[mypy-synapse.streams.*] [mypy-synapse.streams.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.types]
disallow_untyped_defs = True
[mypy-synapse.util.*] [mypy-synapse.util.*]
disallow_untyped_defs = True disallow_untyped_defs = True

View file

@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
import attr import attr
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
@@ -1105,22 +1105,19 @@ class E2eKeysHandler:
# can request over federation # can request over federation
raise NotFoundError("No %s key found for %s" % (key_type, user_id)) raise NotFoundError("No %s key found for %s" % (key_type, user_id))
( cross_signing_keys = await self._retrieve_cross_signing_keys_for_remote_user(
key, user, key_type
key_id, )
verify_key, if cross_signing_keys is None:
) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
if key is None:
raise NotFoundError("No %s key found for %s" % (key_type, user_id)) raise NotFoundError("No %s key found for %s" % (key_type, user_id))
return key, key_id, verify_key return cross_signing_keys
async def _retrieve_cross_signing_keys_for_remote_user( async def _retrieve_cross_signing_keys_for_remote_user(
self, self,
user: UserID, user: UserID,
desired_key_type: str, desired_key_type: str,
) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]: ) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]:
"""Queries cross-signing keys for a remote user and saves them to the database """Queries cross-signing keys for a remote user and saves them to the database
Only the key specified by `key_type` will be returned, while all retrieved keys Only the key specified by `key_type` will be returned, while all retrieved keys
@@ -1146,12 +1143,10 @@ class E2eKeysHandler:
type(e), type(e),
e, e,
) )
return None, None, None return None
# Process each of the retrieved cross-signing keys # Process each of the retrieved cross-signing keys
desired_key = None desired_key_data = None
desired_key_id = None
desired_verify_key = None
retrieved_device_ids = [] retrieved_device_ids = []
for key_type in ["master", "self_signing"]: for key_type in ["master", "self_signing"]:
key_content = remote_result.get(key_type + "_key") key_content = remote_result.get(key_type + "_key")
@@ -1196,9 +1191,7 @@ class E2eKeysHandler:
# If this is the desired key type, save it and its ID/VerifyKey # If this is the desired key type, save it and its ID/VerifyKey
if key_type == desired_key_type: if key_type == desired_key_type:
desired_key = key_content desired_key_data = key_content, key_id, verify_key
desired_verify_key = verify_key
desired_key_id = key_id
# At the same time, store this key in the db for subsequent queries # At the same time, store this key in the db for subsequent queries
await self.store.set_e2e_cross_signing_key( await self.store.set_e2e_cross_signing_key(
@@ -1212,7 +1205,7 @@ class E2eKeysHandler:
user.to_string(), retrieved_device_ids user.to_string(), retrieved_device_ids
) )
return desired_key, desired_key_id, desired_verify_key return desired_key_data
def _check_cross_signing_key( def _check_cross_signing_key(

View file

@@ -14,15 +14,22 @@
import base64 import base64
import logging import logging
from typing import Optional from typing import Optional, Union
import attr import attr
from zope.interface import implementer from zope.interface import implementer
from twisted.internet import defer, protocol from twisted.internet import defer, protocol
from twisted.internet.error import ConnectError from twisted.internet.error import ConnectError
from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint from twisted.internet.interfaces import (
IAddress,
IConnector,
IProtocol,
IReactorCore,
IStreamClientEndpoint,
)
from twisted.internet.protocol import ClientFactory, Protocol, connectionDone from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
from twisted.python.failure import Failure
from twisted.web import http from twisted.web import http
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -81,14 +88,14 @@ class HTTPConnectProxyEndpoint:
self._port = port self._port = port
self._proxy_creds = proxy_creds self._proxy_creds = proxy_creds
def __repr__(self): def __repr__(self) -> str:
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,) return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
# Mypy encounters a false positive here: it complains that ClientFactory # Mypy encounters a false positive here: it complains that ClientFactory
# is incompatible with IProtocolFactory. But ClientFactory inherits from # is incompatible with IProtocolFactory. But ClientFactory inherits from
# Factory, which implements IProtocolFactory. So I think this is a bug # Factory, which implements IProtocolFactory. So I think this is a bug
# in mypy-zope. # in mypy-zope.
def connect(self, protocolFactory: ClientFactory): # type: ignore[override] def connect(self, protocolFactory: ClientFactory) -> "defer.Deferred[IProtocol]": # type: ignore[override]
f = HTTPProxiedClientFactory( f = HTTPProxiedClientFactory(
self._host, self._port, protocolFactory, self._proxy_creds self._host, self._port, protocolFactory, self._proxy_creds
) )
@@ -125,10 +132,10 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
self.proxy_creds = proxy_creds self.proxy_creds = proxy_creds
self.on_connection: "defer.Deferred[None]" = defer.Deferred() self.on_connection: "defer.Deferred[None]" = defer.Deferred()
def startedConnecting(self, connector): def startedConnecting(self, connector: IConnector) -> None:
return self.wrapped_factory.startedConnecting(connector) return self.wrapped_factory.startedConnecting(connector)
def buildProtocol(self, addr): def buildProtocol(self, addr: IAddress) -> "HTTPConnectProtocol":
wrapped_protocol = self.wrapped_factory.buildProtocol(addr) wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
if wrapped_protocol is None: if wrapped_protocol is None:
raise TypeError("buildProtocol produced None instead of a Protocol") raise TypeError("buildProtocol produced None instead of a Protocol")
@@ -141,13 +148,13 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
self.proxy_creds, self.proxy_creds,
) )
def clientConnectionFailed(self, connector, reason): def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
logger.debug("Connection to proxy failed: %s", reason) logger.debug("Connection to proxy failed: %s", reason)
if not self.on_connection.called: if not self.on_connection.called:
self.on_connection.errback(reason) self.on_connection.errback(reason)
return self.wrapped_factory.clientConnectionFailed(connector, reason) return self.wrapped_factory.clientConnectionFailed(connector, reason)
def clientConnectionLost(self, connector, reason): def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
logger.debug("Connection to proxy lost: %s", reason) logger.debug("Connection to proxy lost: %s", reason)
if not self.on_connection.called: if not self.on_connection.called:
self.on_connection.errback(reason) self.on_connection.errback(reason)
@@ -191,10 +198,10 @@ class HTTPConnectProtocol(protocol.Protocol):
) )
self.http_setup_client.on_connected.addCallback(self.proxyConnected) self.http_setup_client.on_connected.addCallback(self.proxyConnected)
def connectionMade(self): def connectionMade(self) -> None:
self.http_setup_client.makeConnection(self.transport) self.http_setup_client.makeConnection(self.transport)
def connectionLost(self, reason=connectionDone): def connectionLost(self, reason: Failure = connectionDone) -> None:
if self.wrapped_protocol.connected: if self.wrapped_protocol.connected:
self.wrapped_protocol.connectionLost(reason) self.wrapped_protocol.connectionLost(reason)
@@ -203,7 +210,7 @@ class HTTPConnectProtocol(protocol.Protocol):
if not self.connected_deferred.called: if not self.connected_deferred.called:
self.connected_deferred.errback(reason) self.connected_deferred.errback(reason)
def proxyConnected(self, _): def proxyConnected(self, _: Union[None, "defer.Deferred[None]"]) -> None:
self.wrapped_protocol.makeConnection(self.transport) self.wrapped_protocol.makeConnection(self.transport)
self.connected_deferred.callback(self.wrapped_protocol) self.connected_deferred.callback(self.wrapped_protocol)
@@ -213,7 +220,7 @@ class HTTPConnectProtocol(protocol.Protocol):
if buf: if buf:
self.wrapped_protocol.dataReceived(buf) self.wrapped_protocol.dataReceived(buf)
def dataReceived(self, data: bytes): def dataReceived(self, data: bytes) -> None:
# if we've set up the HTTP protocol, we can send the data there # if we've set up the HTTP protocol, we can send the data there
if self.wrapped_protocol.connected: if self.wrapped_protocol.connected:
return self.wrapped_protocol.dataReceived(data) return self.wrapped_protocol.dataReceived(data)
@@ -243,7 +250,7 @@ class HTTPConnectSetupClient(http.HTTPClient):
self.proxy_creds = proxy_creds self.proxy_creds = proxy_creds
self.on_connected: "defer.Deferred[None]" = defer.Deferred() self.on_connected: "defer.Deferred[None]" = defer.Deferred()
def connectionMade(self): def connectionMade(self) -> None:
logger.debug("Connected to proxy, sending CONNECT") logger.debug("Connected to proxy, sending CONNECT")
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
@@ -257,14 +264,14 @@ class HTTPConnectSetupClient(http.HTTPClient):
self.endHeaders() self.endHeaders()
def handleStatus(self, version: bytes, status: bytes, message: bytes): def handleStatus(self, version: bytes, status: bytes, message: bytes) -> None:
logger.debug("Got Status: %s %s %s", status, message, version) logger.debug("Got Status: %s %s %s", status, message, version)
if status != b"200": if status != b"200":
raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}") raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}")
def handleEndHeaders(self): def handleEndHeaders(self) -> None:
logger.debug("End Headers") logger.debug("End Headers")
self.on_connected.callback(None) self.on_connected.callback(None)
def handleResponse(self, body): def handleResponse(self, body: bytes) -> None:
pass pass

View file

@@ -245,7 +245,7 @@ def http_proxy_endpoint(
proxy: Optional[bytes], proxy: Optional[bytes],
reactor: IReactorCore, reactor: IReactorCore,
tls_options_factory: Optional[IPolicyForHTTPS], tls_options_factory: Optional[IPolicyForHTTPS],
**kwargs, **kwargs: object,
) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]: ) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
"""Parses an http proxy setting and returns an endpoint for the proxy """Parses an http proxy setting and returns an endpoint for the proxy

View file

@@ -31,7 +31,11 @@ from twisted.internet.endpoints import (
TCP4ClientEndpoint, TCP4ClientEndpoint,
TCP6ClientEndpoint, TCP6ClientEndpoint,
) )
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint from twisted.internet.interfaces import (
IPushProducer,
IReactorTCP,
IStreamClientEndpoint,
)
from twisted.internet.protocol import Factory, Protocol from twisted.internet.protocol import Factory, Protocol
from twisted.internet.tcp import Connection from twisted.internet.tcp import Connection
from twisted.python.failure import Failure from twisted.python.failure import Failure
@@ -59,14 +63,14 @@ class LogProducer:
_buffer: Deque[logging.LogRecord] _buffer: Deque[logging.LogRecord]
_paused: bool = attr.ib(default=False, init=False) _paused: bool = attr.ib(default=False, init=False)
def pauseProducing(self): def pauseProducing(self) -> None:
self._paused = True self._paused = True
def stopProducing(self): def stopProducing(self) -> None:
self._paused = True self._paused = True
self._buffer = deque() self._buffer = deque()
def resumeProducing(self): def resumeProducing(self) -> None:
# If we're already producing, nothing to do. # If we're already producing, nothing to do.
self._paused = False self._paused = False
@@ -102,8 +106,8 @@ class RemoteHandler(logging.Handler):
host: str, host: str,
port: int, port: int,
maximum_buffer: int = 1000, maximum_buffer: int = 1000,
level=logging.NOTSET, level: int = logging.NOTSET,
_reactor=None, _reactor: Optional[IReactorTCP] = None,
): ):
super().__init__(level=level) super().__init__(level=level)
self.host = host self.host = host
@@ -118,7 +122,7 @@ class RemoteHandler(logging.Handler):
if _reactor is None: if _reactor is None:
from twisted.internet import reactor from twisted.internet import reactor
_reactor = reactor _reactor = reactor # type: ignore[assignment]
try: try:
ip = ip_address(self.host) ip = ip_address(self.host)
@@ -139,7 +143,7 @@ class RemoteHandler(logging.Handler):
self._stopping = False self._stopping = False
self._connect() self._connect()
def close(self): def close(self) -> None:
self._stopping = True self._stopping = True
self._service.stopService() self._service.stopService()

View file

@@ -16,6 +16,8 @@
import logging import logging
import traceback import traceback
from io import StringIO from io import StringIO
from types import TracebackType
from typing import Optional, Tuple, Type
class LogFormatter(logging.Formatter): class LogFormatter(logging.Formatter):
@@ -28,10 +30,14 @@ class LogFormatter(logging.Formatter):
where it was caught are logged). where it was caught are logged).
""" """
def __init__(self, *args, **kwargs): def formatException(
super().__init__(*args, **kwargs) self,
ei: Tuple[
def formatException(self, ei): Optional[Type[BaseException]],
Optional[BaseException],
Optional[TracebackType],
],
) -> str:
sio = StringIO() sio = StringIO()
(typ, val, tb) = ei (typ, val, tb) = ei

View file

@@ -49,7 +49,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
) )
self._flushing_thread.start() self._flushing_thread.start()
def on_reactor_running(): def on_reactor_running() -> None:
self._reactor_started = True self._reactor_started = True
reactor_to_use: IReactorCore reactor_to_use: IReactorCore
@@ -74,7 +74,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
else: else:
return True return True
def _flush_periodically(self): def _flush_periodically(self) -> None:
""" """
Whilst this handler is active, flush the handler periodically. Whilst this handler is active, flush the handler periodically.
""" """

View file

@@ -13,6 +13,8 @@
# limitations under the License.import logging # limitations under the License.import logging
import logging import logging
from types import TracebackType
from typing import Optional, Type
from opentracing import Scope, ScopeManager from opentracing import Scope, ScopeManager
@@ -107,19 +109,26 @@ class _LogContextScope(Scope):
and - if enter_logcontext was set - the logcontext is finished too. and - if enter_logcontext was set - the logcontext is finished too.
""" """
def __init__(self, manager, span, logcontext, enter_logcontext, finish_on_close): def __init__(
self,
manager: LogContextScopeManager,
span,
logcontext,
enter_logcontext: bool,
finish_on_close: bool,
):
""" """
Args: Args:
manager (LogContextScopeManager): manager:
the manager that is responsible for this scope. the manager that is responsible for this scope.
span (Span): span (Span):
the opentracing span which this scope represents the local the opentracing span which this scope represents the local
lifetime for. lifetime for.
logcontext (LogContext): logcontext (LogContext):
the logcontext to which this scope is attached. the logcontext to which this scope is attached.
enter_logcontext (Boolean): enter_logcontext:
if True the logcontext will be exited when the scope is finished if True the logcontext will be exited when the scope is finished
finish_on_close (Boolean): finish_on_close:
if True finish the span when the scope is closed if True finish the span when the scope is closed
""" """
super().__init__(manager, span) super().__init__(manager, span)
@@ -127,16 +136,21 @@ class _LogContextScope(Scope):
self._finish_on_close = finish_on_close self._finish_on_close = finish_on_close
self._enter_logcontext = enter_logcontext self._enter_logcontext = enter_logcontext
def __exit__(self, exc_type, value, traceback): def __exit__(
self,
exc_type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
if exc_type == twisted.internet.defer._DefGen_Return: if exc_type == twisted.internet.defer._DefGen_Return:
# filter out defer.returnValue() calls # filter out defer.returnValue() calls
exc_type = value = traceback = None exc_type = value = traceback = None
super().__exit__(exc_type, value, traceback) super().__exit__(exc_type, value, traceback)
def __str__(self): def __str__(self) -> str:
return f"Scope<{self.span}>" return f"Scope<{self.span}>"
def close(self): def close(self) -> None:
active_scope = self.manager.active active_scope = self.manager.active
if active_scope is not self: if active_scope is not self:
logger.error( logger.error(

View file

@@ -12,20 +12,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from types import TracebackType
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any,
AsyncContextManager, AsyncContextManager,
Awaitable, Awaitable,
Callable, Callable,
Dict, Dict,
Iterable, Iterable,
List,
Optional, Optional,
Type,
) )
import attr import attr
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.types import Connection from synapse.storage.types import Connection, Cursor
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock, json_encoder from synapse.util import Clock, json_encoder
@@ -74,7 +78,12 @@ class _BackgroundUpdateContextManager:
return self._update_duration_ms return self._update_duration_ms
async def __aexit__(self, *exc) -> None: async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
pass pass
@@ -352,7 +361,7 @@ class BackgroundUpdater:
True if we have finished running all the background updates, otherwise False True if we have finished running all the background updates, otherwise False
""" """
def get_background_updates_txn(txn): def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]:
txn.execute( txn.execute(
""" """
SELECT update_name, depends_on FROM background_updates SELECT update_name, depends_on FROM background_updates
@@ -469,7 +478,7 @@ class BackgroundUpdater:
self, self,
update_name: str, update_name: str,
update_handler: Callable[[JsonDict, int], Awaitable[int]], update_handler: Callable[[JsonDict, int], Awaitable[int]],
): ) -> None:
"""Register a handler for doing a background update. """Register a handler for doing a background update.
The handler should take two arguments: The handler should take two arguments:
@@ -603,7 +612,7 @@ class BackgroundUpdater:
else: else:
runner = create_index_sqlite runner = create_index_sqlite
async def updater(progress, batch_size): async def updater(progress: JsonDict, batch_size: int) -> int:
if runner is not None: if runner is not None:
logger.info("Adding index %s to %s", index_name, table) logger.info("Adding index %s to %s", index_name, table)
await self.db_pool.runWithConnection(runner) await self.db_pool.runWithConnection(runner)

View file

@@ -24,6 +24,7 @@ from typing import (
Mapping, Mapping,
Match, Match,
MutableMapping, MutableMapping,
NoReturn,
Optional, Optional,
Set, Set,
Tuple, Tuple,
@@ -35,6 +36,7 @@ from typing import (
import attr import attr
from frozendict import frozendict from frozendict import frozendict
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from signedjson.types import VerifyKey
from typing_extensions import TypedDict from typing_extensions import TypedDict
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
from zope.interface import Interface from zope.interface import Interface
@@ -55,6 +57,7 @@ from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.appservice.api import ApplicationService from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore, PurgeEventsStore from synapse.storage.databases.main import DataStore, PurgeEventsStore
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
# Define a state map type from type/state_key to T (usually an event ID or # Define a state map type from type/state_key to T (usually an event ID or
# event) # event)
@@ -114,7 +117,7 @@ class Requester:
app_service: Optional["ApplicationService"] app_service: Optional["ApplicationService"]
authenticated_entity: str authenticated_entity: str
def serialize(self): def serialize(self) -> Dict[str, Any]:
"""Converts self to a type that can be serialized as JSON, and then """Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize` deserialized by `deserialize`
@@ -132,7 +135,9 @@ class Requester:
} }
@staticmethod @staticmethod
def deserialize(store, input): def deserialize(
store: "ApplicationServiceWorkerStore", input: Dict[str, Any]
) -> "Requester":
"""Converts a dict that was produced by `serialize` back into a """Converts a dict that was produced by `serialize` back into a
Requester. Requester.
@@ -236,10 +241,10 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
domain: str domain: str
# Because this is a frozen class, it is deeply immutable. # Because this is a frozen class, it is deeply immutable.
def __copy__(self): def __copy__(self: DS) -> DS:
return self return self
def __deepcopy__(self, memo): def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS:
return self return self
@classmethod @classmethod
@@ -729,12 +734,14 @@ class StreamToken:
) )
@property @property
def room_stream_id(self): def room_stream_id(self) -> int:
return self.room_key.stream return self.room_key.stream
def copy_and_advance(self, key, new_value) -> "StreamToken": def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken":
"""Advance the given key in the token to a new value if and only if the """Advance the given key in the token to a new value if and only if the
new value is after the old value. new value is after the old value.
:raises TypeError: if `key` is not the one of the keys tracked by a StreamToken.
""" """
if key == "room_key": if key == "room_key":
new_token = self.copy_and_replace( new_token = self.copy_and_replace(
@@ -751,7 +758,7 @@ class StreamToken:
else: else:
return self return self
def copy_and_replace(self, key, new_value) -> "StreamToken": def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken":
return attr.evolve(self, **{key: new_value}) return attr.evolve(self, **{key: new_value})
@@ -793,14 +800,14 @@ class ThirdPartyInstanceID:
# Deny iteration because it will bite you if you try to create a singleton # Deny iteration because it will bite you if you try to create a singleton
# set by: # set by:
# users = set(user) # users = set(user)
def __iter__(self): def __iter__(self) -> NoReturn:
raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
# Because this class is a frozen class, it is deeply immutable. # Because this class is a frozen class, it is deeply immutable.
def __copy__(self): def __copy__(self) -> "ThirdPartyInstanceID":
return self return self
def __deepcopy__(self, memo): def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID":
return self return self
@classmethod @classmethod
@@ -852,25 +859,28 @@ class DeviceListUpdates:
return bool(self.changed or self.left) return bool(self.changed or self.left)
def get_verify_key_from_cross_signing_key(key_info): def get_verify_key_from_cross_signing_key(
key_info: Mapping[str, Any]
) -> Tuple[str, VerifyKey]:
"""Get the key ID and signedjson verify key from a cross-signing key dict """Get the key ID and signedjson verify key from a cross-signing key dict
Args: Args:
key_info (dict): a cross-signing key dict, which must have a "keys" key_info: a cross-signing key dict, which must have a "keys"
property that has exactly one item in it property that has exactly one item in it
Returns: Returns:
(str, VerifyKey): the key ID and verify key for the cross-signing key the key ID and verify key for the cross-signing key
""" """
# make sure that exactly one key is provided # make sure that a `keys` field is provided
if "keys" not in key_info: if "keys" not in key_info:
raise ValueError("Invalid key") raise ValueError("Invalid key")
keys = key_info["keys"] keys = key_info["keys"]
if len(keys) != 1: # and that it contains exactly one key
raise ValueError("Invalid key") if len(keys) == 1:
# and return that one key key_id, key_data = next(iter(keys.items()))
for key_id, key_data in keys.items():
return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)) return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))
else:
raise ValueError("Invalid key")
@attr.s(auto_attribs=True, frozen=True, slots=True) @attr.s(auto_attribs=True, frozen=True, slots=True)