diff --git a/changelog.d/11146.misc b/changelog.d/11146.misc new file mode 100644 index 000000000..6ce1c9f9f --- /dev/null +++ b/changelog.d/11146.misc @@ -0,0 +1 @@ +Add missing type hints to `synapse.crypto`. diff --git a/mypy.ini b/mypy.ini index 14d8bb8ea..c5f44aea3 100644 --- a/mypy.ini +++ b/mypy.ini @@ -103,6 +103,9 @@ files = [mypy-synapse.api.*] disallow_untyped_defs = True +[mypy-synapse.crypto.*] +disallow_untyped_defs = True + [mypy-synapse.events.*] disallow_untyped_defs = True diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 2a6110eb1..7855f3498 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -29,9 +29,12 @@ from twisted.internet.ssl import ( TLSVersion, platformTrust, ) +from twisted.protocols.tls import TLSMemoryBIOProtocol from twisted.python.failure import Failure from twisted.web.iweb import IPolicyForHTTPS +from synapse.config.homeserver import HomeServerConfig + logger = logging.getLogger(__name__) @@ -51,7 +54,7 @@ class ServerContextFactory(ContextFactory): per https://github.com/matrix-org/synapse/issues/1691 """ - def __init__(self, config): + def __init__(self, config: HomeServerConfig): # TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version, # switch to those (see https://github.com/pyca/cryptography/issues/5379). # @@ -64,7 +67,7 @@ class ServerContextFactory(ContextFactory): self.configure_context(self._context, config) @staticmethod - def configure_context(context, config): + def configure_context(context: SSL.Context, config: HomeServerConfig) -> None: try: _ecCurve = crypto.get_elliptic_curve(_defaultCurveName) context.set_tmp_ecdh(_ecCurve) @@ -75,14 +78,15 @@ class ServerContextFactory(ContextFactory): SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_NO_TLSv1 | SSL.OP_NO_TLSv1_1 ) context.use_certificate_chain_file(config.tls.tls_certificate_file) + assert config.tls.tls_private_key is not None context.use_privatekey(config.tls.tls_private_key) # https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/ context.set_cipher_list( - "ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1:!AESCCM" + b"ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1:!AESCCM" ) - def getContext(self): + def getContext(self) -> SSL.Context: return self._context @@ -98,7 +102,7 @@ class FederationPolicyForHTTPS: constructs an SSLClientConnectionCreator factory accordingly. """ - def __init__(self, config): + def __init__(self, config: HomeServerConfig): self._config = config # Check if we're using a custom list of a CA certificates @@ -131,7 +135,7 @@ class FederationPolicyForHTTPS: self._config.tls.federation_certificate_verification_whitelist ) - def get_options(self, host: bytes): + def get_options(self, host: bytes) -> IOpenSSLClientConnectionCreator: # IPolicyForHTTPS.get_options takes bytes, but we want to compare # against the str whitelist. The hostnames in the whitelist are already # IDNA-encoded like the hosts will be here. @@ -153,7 +157,9 @@ class FederationPolicyForHTTPS: return SSLClientConnectionCreator(host, ssl_context, should_verify) - def creatorForNetloc(self, hostname, port): + def creatorForNetloc( + self, hostname: bytes, port: int + ) -> IOpenSSLClientConnectionCreator: """Implements the IPolicyForHTTPS interface so that this can be passed directly to agents. """ @@ -169,16 +175,18 @@ class RegularPolicyForHTTPS: trust root. """ - def __init__(self): + def __init__(self) -> None: trust_root = platformTrust() self._ssl_context = CertificateOptions(trustRoot=trust_root).getContext() self._ssl_context.set_info_callback(_context_info_cb) - def creatorForNetloc(self, hostname, port): + def creatorForNetloc( + self, hostname: bytes, port: int + ) -> IOpenSSLClientConnectionCreator: return SSLClientConnectionCreator(hostname, self._ssl_context, True) -def _context_info_cb(ssl_connection, where, ret): +def _context_info_cb(ssl_connection: SSL.Connection, where: int, ret: int) -> None: """The 'information callback' for our openssl context objects. Note: Once this is set as the info callback on a Context object, the Context should @@ -204,11 +212,13 @@ class SSLClientConnectionCreator: Replaces twisted.internet.ssl.ClientTLSOptions """ - def __init__(self, hostname: bytes, ctx, verify_certs: bool): + def __init__(self, hostname: bytes, ctx: SSL.Context, verify_certs: bool): self._ctx = ctx self._verifier = ConnectionVerifier(hostname, verify_certs) - def clientConnectionForTLS(self, tls_protocol): + def clientConnectionForTLS( + self, tls_protocol: TLSMemoryBIOProtocol + ) -> SSL.Connection: context = self._ctx connection = SSL.Connection(context, None) @@ -219,7 +229,7 @@ class SSLClientConnectionCreator: # ... and we also gut-wrench a '_synapse_tls_verifier' attribute into the # tls_protocol so that the SSL context's info callback has something to # call to do the cert verification. - tls_protocol._synapse_tls_verifier = self._verifier + tls_protocol._synapse_tls_verifier = self._verifier # type: ignore[attr-defined] return connection @@ -244,7 +254,9 @@ class ConnectionVerifier: self._hostnameBytes = hostname self._hostnameASCII = self._hostnameBytes.decode("ascii") - def verify_context_info_cb(self, ssl_connection, where): + def verify_context_info_cb( + self, ssl_connection: SSL.Connection, where: int + ) -> None: if where & SSL.SSL_CB_HANDSHAKE_START and not self._is_ip_address: ssl_connection.set_tlsext_host_name(self._hostnameBytes) diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 0f2b632e4..7520647d1 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -100,7 +100,7 @@ def compute_content_hash( def compute_event_reference_hash( - event, hash_algorithm: Hasher = hashlib.sha256 + event: EventBase, hash_algorithm: Hasher = hashlib.sha256 ) -> Tuple[str, bytes]: """Computes the event reference hash. This is the hash of the redacted event. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index e1e13a241..8628e951c 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -87,7 +87,7 @@ class VerifyJsonRequest: server_name: str, json_object: JsonDict, minimum_valid_until_ms: int, - ): + ) -> "VerifyJsonRequest": """Create a VerifyJsonRequest to verify all signatures on a signed JSON object for the given server. """ @@ -104,7 +104,7 @@ class VerifyJsonRequest: server_name: str, event: EventBase, minimum_valid_until_ms: int, - ): + ) -> "VerifyJsonRequest": """Create a VerifyJsonRequest to verify all signatures on an event object for the given server. """ @@ -449,7 +449,9 @@ class StoreKeyFetcher(KeyFetcher): self.store = hs.get_datastore() - async def _fetch_keys(self, keys_to_fetch: List[_FetchKeyRequest]): + async def _fetch_keys( + self, keys_to_fetch: List[_FetchKeyRequest] + ) -> Dict[str, Dict[str, FetchKeyResult]]: key_ids_to_fetch = ( (queue_value.server_name, key_id) for queue_value in keys_to_fetch