0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-13 21:33:20 +01:00

Add missing type hints to synapse.crypto. (#11146)

And require type hints for this module.
This commit is contained in:
Patrick Cloke 2021-10-21 09:07:07 -04:00 committed by GitHub
parent 09eff1b3db
commit 0f9adc99ad
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 18 deletions

1
changelog.d/11146.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to `synapse.crypto`.

View file

@ -103,6 +103,9 @@ files =
[mypy-synapse.api.*] [mypy-synapse.api.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.crypto.*]
disallow_untyped_defs = True
[mypy-synapse.events.*] [mypy-synapse.events.*]
disallow_untyped_defs = True disallow_untyped_defs = True

View file

@ -29,9 +29,12 @@ from twisted.internet.ssl import (
TLSVersion, TLSVersion,
platformTrust, platformTrust,
) )
from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.web.iweb import IPolicyForHTTPS from twisted.web.iweb import IPolicyForHTTPS
from synapse.config.homeserver import HomeServerConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -51,7 +54,7 @@ class ServerContextFactory(ContextFactory):
per https://github.com/matrix-org/synapse/issues/1691 per https://github.com/matrix-org/synapse/issues/1691
""" """
def __init__(self, config): def __init__(self, config: HomeServerConfig):
# TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version, # TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version,
# switch to those (see https://github.com/pyca/cryptography/issues/5379). # switch to those (see https://github.com/pyca/cryptography/issues/5379).
# #
@ -64,7 +67,7 @@ class ServerContextFactory(ContextFactory):
self.configure_context(self._context, config) self.configure_context(self._context, config)
@staticmethod @staticmethod
def configure_context(context, config): def configure_context(context: SSL.Context, config: HomeServerConfig) -> None:
try: try:
_ecCurve = crypto.get_elliptic_curve(_defaultCurveName) _ecCurve = crypto.get_elliptic_curve(_defaultCurveName)
context.set_tmp_ecdh(_ecCurve) context.set_tmp_ecdh(_ecCurve)
@ -75,14 +78,15 @@ class ServerContextFactory(ContextFactory):
SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_NO_TLSv1 | SSL.OP_NO_TLSv1_1 SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_NO_TLSv1 | SSL.OP_NO_TLSv1_1
) )
context.use_certificate_chain_file(config.tls.tls_certificate_file) context.use_certificate_chain_file(config.tls.tls_certificate_file)
assert config.tls.tls_private_key is not None
context.use_privatekey(config.tls.tls_private_key) context.use_privatekey(config.tls.tls_private_key)
# https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/ # https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
context.set_cipher_list( context.set_cipher_list(
"ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1:!AESCCM" b"ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1:!AESCCM"
) )
def getContext(self): def getContext(self) -> SSL.Context:
return self._context return self._context
@ -98,7 +102,7 @@ class FederationPolicyForHTTPS:
constructs an SSLClientConnectionCreator factory accordingly. constructs an SSLClientConnectionCreator factory accordingly.
""" """
def __init__(self, config): def __init__(self, config: HomeServerConfig):
self._config = config self._config = config
# Check if we're using a custom list of a CA certificates # Check if we're using a custom list of a CA certificates
@ -131,7 +135,7 @@ class FederationPolicyForHTTPS:
self._config.tls.federation_certificate_verification_whitelist self._config.tls.federation_certificate_verification_whitelist
) )
def get_options(self, host: bytes): def get_options(self, host: bytes) -> IOpenSSLClientConnectionCreator:
# IPolicyForHTTPS.get_options takes bytes, but we want to compare # IPolicyForHTTPS.get_options takes bytes, but we want to compare
# against the str whitelist. The hostnames in the whitelist are already # against the str whitelist. The hostnames in the whitelist are already
# IDNA-encoded like the hosts will be here. # IDNA-encoded like the hosts will be here.
@ -153,7 +157,9 @@ class FederationPolicyForHTTPS:
return SSLClientConnectionCreator(host, ssl_context, should_verify) return SSLClientConnectionCreator(host, ssl_context, should_verify)
def creatorForNetloc(self, hostname, port): def creatorForNetloc(
self, hostname: bytes, port: int
) -> IOpenSSLClientConnectionCreator:
"""Implements the IPolicyForHTTPS interface so that this can be passed """Implements the IPolicyForHTTPS interface so that this can be passed
directly to agents. directly to agents.
""" """
@ -169,16 +175,18 @@ class RegularPolicyForHTTPS:
trust root. trust root.
""" """
def __init__(self): def __init__(self) -> None:
trust_root = platformTrust() trust_root = platformTrust()
self._ssl_context = CertificateOptions(trustRoot=trust_root).getContext() self._ssl_context = CertificateOptions(trustRoot=trust_root).getContext()
self._ssl_context.set_info_callback(_context_info_cb) self._ssl_context.set_info_callback(_context_info_cb)
def creatorForNetloc(self, hostname, port): def creatorForNetloc(
self, hostname: bytes, port: int
) -> IOpenSSLClientConnectionCreator:
return SSLClientConnectionCreator(hostname, self._ssl_context, True) return SSLClientConnectionCreator(hostname, self._ssl_context, True)
def _context_info_cb(ssl_connection, where, ret): def _context_info_cb(ssl_connection: SSL.Connection, where: int, ret: int) -> None:
"""The 'information callback' for our openssl context objects. """The 'information callback' for our openssl context objects.
Note: Once this is set as the info callback on a Context object, the Context should Note: Once this is set as the info callback on a Context object, the Context should
@ -204,11 +212,13 @@ class SSLClientConnectionCreator:
Replaces twisted.internet.ssl.ClientTLSOptions Replaces twisted.internet.ssl.ClientTLSOptions
""" """
def __init__(self, hostname: bytes, ctx, verify_certs: bool): def __init__(self, hostname: bytes, ctx: SSL.Context, verify_certs: bool):
self._ctx = ctx self._ctx = ctx
self._verifier = ConnectionVerifier(hostname, verify_certs) self._verifier = ConnectionVerifier(hostname, verify_certs)
def clientConnectionForTLS(self, tls_protocol): def clientConnectionForTLS(
self, tls_protocol: TLSMemoryBIOProtocol
) -> SSL.Connection:
context = self._ctx context = self._ctx
connection = SSL.Connection(context, None) connection = SSL.Connection(context, None)
@ -219,7 +229,7 @@ class SSLClientConnectionCreator:
# ... and we also gut-wrench a '_synapse_tls_verifier' attribute into the # ... and we also gut-wrench a '_synapse_tls_verifier' attribute into the
# tls_protocol so that the SSL context's info callback has something to # tls_protocol so that the SSL context's info callback has something to
# call to do the cert verification. # call to do the cert verification.
tls_protocol._synapse_tls_verifier = self._verifier tls_protocol._synapse_tls_verifier = self._verifier # type: ignore[attr-defined]
return connection return connection
@ -244,7 +254,9 @@ class ConnectionVerifier:
self._hostnameBytes = hostname self._hostnameBytes = hostname
self._hostnameASCII = self._hostnameBytes.decode("ascii") self._hostnameASCII = self._hostnameBytes.decode("ascii")
def verify_context_info_cb(self, ssl_connection, where): def verify_context_info_cb(
self, ssl_connection: SSL.Connection, where: int
) -> None:
if where & SSL.SSL_CB_HANDSHAKE_START and not self._is_ip_address: if where & SSL.SSL_CB_HANDSHAKE_START and not self._is_ip_address:
ssl_connection.set_tlsext_host_name(self._hostnameBytes) ssl_connection.set_tlsext_host_name(self._hostnameBytes)

View file

@ -100,7 +100,7 @@ def compute_content_hash(
def compute_event_reference_hash( def compute_event_reference_hash(
event, hash_algorithm: Hasher = hashlib.sha256 event: EventBase, hash_algorithm: Hasher = hashlib.sha256
) -> Tuple[str, bytes]: ) -> Tuple[str, bytes]:
"""Computes the event reference hash. This is the hash of the redacted """Computes the event reference hash. This is the hash of the redacted
event. event.

View file

@ -87,7 +87,7 @@ class VerifyJsonRequest:
server_name: str, server_name: str,
json_object: JsonDict, json_object: JsonDict,
minimum_valid_until_ms: int, minimum_valid_until_ms: int,
): ) -> "VerifyJsonRequest":
"""Create a VerifyJsonRequest to verify all signatures on a signed JSON """Create a VerifyJsonRequest to verify all signatures on a signed JSON
object for the given server. object for the given server.
""" """
@ -104,7 +104,7 @@ class VerifyJsonRequest:
server_name: str, server_name: str,
event: EventBase, event: EventBase,
minimum_valid_until_ms: int, minimum_valid_until_ms: int,
): ) -> "VerifyJsonRequest":
"""Create a VerifyJsonRequest to verify all signatures on an event """Create a VerifyJsonRequest to verify all signatures on an event
object for the given server. object for the given server.
""" """
@ -449,7 +449,9 @@ class StoreKeyFetcher(KeyFetcher):
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def _fetch_keys(self, keys_to_fetch: List[_FetchKeyRequest]): async def _fetch_keys(
self, keys_to_fetch: List[_FetchKeyRequest]
) -> Dict[str, Dict[str, FetchKeyResult]]:
key_ids_to_fetch = ( key_ids_to_fetch = (
(queue_value.server_name, key_id) (queue_value.server_name, key_id)
for queue_value in keys_to_fetch for queue_value in keys_to_fetch