mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 20:33:51 +01:00
Fail quicker for 4xx responses in the key client, optional hit a different API path
This commit is contained in:
parent
32e14d8181
commit
8d761134c2
1 changed files with 31 additions and 6 deletions
|
@ -25,12 +25,15 @@ import logging
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KEY_API_V1 = b"/_matrix/key/v1/"
|
||||
KEY_API_V2 = b"/_matrix/key/v2/local"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fetch_server_key(server_name, ssl_context_factory):
|
||||
def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
|
||||
"""Fetch the keys for a remote server."""
|
||||
|
||||
factory = SynapseKeyClientFactory()
|
||||
factory.path = path
|
||||
endpoint = matrix_federation_endpoint(
|
||||
reactor, server_name, ssl_context_factory, timeout=30
|
||||
)
|
||||
|
@ -42,13 +45,19 @@ def fetch_server_key(server_name, ssl_context_factory):
|
|||
server_response, server_certificate = yield protocol.remote_key
|
||||
defer.returnValue((server_response, server_certificate))
|
||||
return
|
||||
except SynapseKeyClientError as e:
|
||||
logger.exception("Error getting key for %r" % (server_name,))
|
||||
if e.status.startswith("4"):
|
||||
# Don't retry for 4xx responses.
|
||||
raise IOError("Cannot get key for %r" % server_name)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise IOError("Cannot get key for %s" % server_name)
|
||||
raise IOError("Cannot get key for %r" % server_name)
|
||||
|
||||
|
||||
class SynapseKeyClientError(Exception):
|
||||
"""The key wasn't retrieved from the remote server."""
|
||||
status = None
|
||||
pass
|
||||
|
||||
|
||||
|
@ -66,17 +75,30 @@ class SynapseKeyClientProtocol(HTTPClient):
|
|||
def connectionMade(self):
|
||||
self.host = self.transport.getHost()
|
||||
logger.debug("Connected to %s", self.host)
|
||||
self.sendCommand(b"GET", b"/_matrix/key/v1/")
|
||||
self.sendCommand(b"GET", self.path)
|
||||
self.endHeaders()
|
||||
self.timer = reactor.callLater(
|
||||
self.timeout,
|
||||
self.on_timeout
|
||||
)
|
||||
|
||||
def errback(self, error):
|
||||
if not self.remote_key.called:
|
||||
self.remote_key.errback(error)
|
||||
|
||||
def callback(self, result):
|
||||
if not self.remote_key.called:
|
||||
self.remote_key.callback(result)
|
||||
|
||||
def handleStatus(self, version, status, message):
|
||||
if status != b"200":
|
||||
# logger.info("Non-200 response from %s: %s %s",
|
||||
# self.transport.getHost(), status, message)
|
||||
error = SynapseKeyClientError("Non-200 response %r from %r" %
|
||||
(status, self.host)
|
||||
)
|
||||
error.status = status
|
||||
self.errback(error)
|
||||
self.transport.abortConnection()
|
||||
|
||||
def handleResponse(self, response_body_bytes):
|
||||
|
@ -89,15 +111,18 @@ class SynapseKeyClientProtocol(HTTPClient):
|
|||
return
|
||||
|
||||
certificate = self.transport.getPeerCertificate()
|
||||
self.remote_key.callback((json_response, certificate))
|
||||
self.callback((json_response, certificate))
|
||||
self.transport.abortConnection()
|
||||
self.timer.cancel()
|
||||
|
||||
def on_timeout(self):
|
||||
logger.debug("Timeout waiting for response from %s", self.host)
|
||||
self.remote_key.errback(IOError("Timeout waiting for response"))
|
||||
self.errback(IOError("Timeout waiting for response"))
|
||||
self.transport.abortConnection()
|
||||
|
||||
|
||||
class SynapseKeyClientFactory(Factory):
|
||||
protocol = SynapseKeyClientProtocol
|
||||
def protocol(self):
|
||||
protocol = SynapseKeyClientProtocol()
|
||||
protocol.path = self.path
|
||||
return protocol
|
||||
|
|
Loading…
Reference in a new issue