mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 21:03:51 +01:00
Fail quicker for 4xx responses in the key client, optional hit a different API path
This commit is contained in:
parent
32e14d8181
commit
8d761134c2
1 changed files with 31 additions and 6 deletions
|
@ -25,12 +25,15 @@ import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
KEY_API_V1 = b"/_matrix/key/v1/"
|
||||||
|
KEY_API_V2 = b"/_matrix/key/v2/local"
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def fetch_server_key(server_name, ssl_context_factory):
|
def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
|
||||||
"""Fetch the keys for a remote server."""
|
"""Fetch the keys for a remote server."""
|
||||||
|
|
||||||
factory = SynapseKeyClientFactory()
|
factory = SynapseKeyClientFactory()
|
||||||
|
factory.path = path
|
||||||
endpoint = matrix_federation_endpoint(
|
endpoint = matrix_federation_endpoint(
|
||||||
reactor, server_name, ssl_context_factory, timeout=30
|
reactor, server_name, ssl_context_factory, timeout=30
|
||||||
)
|
)
|
||||||
|
@ -42,13 +45,19 @@ def fetch_server_key(server_name, ssl_context_factory):
|
||||||
server_response, server_certificate = yield protocol.remote_key
|
server_response, server_certificate = yield protocol.remote_key
|
||||||
defer.returnValue((server_response, server_certificate))
|
defer.returnValue((server_response, server_certificate))
|
||||||
return
|
return
|
||||||
|
except SynapseKeyClientError as e:
|
||||||
|
logger.exception("Error getting key for %r" % (server_name,))
|
||||||
|
if e.status.startswith("4"):
|
||||||
|
# Don't retry for 4xx responses.
|
||||||
|
raise IOError("Cannot get key for %r" % server_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
raise IOError("Cannot get key for %s" % server_name)
|
raise IOError("Cannot get key for %r" % server_name)
|
||||||
|
|
||||||
|
|
||||||
class SynapseKeyClientError(Exception):
|
class SynapseKeyClientError(Exception):
|
||||||
"""The key wasn't retrieved from the remote server."""
|
"""The key wasn't retrieved from the remote server."""
|
||||||
|
status = None
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -66,17 +75,30 @@ class SynapseKeyClientProtocol(HTTPClient):
|
||||||
def connectionMade(self):
|
def connectionMade(self):
|
||||||
self.host = self.transport.getHost()
|
self.host = self.transport.getHost()
|
||||||
logger.debug("Connected to %s", self.host)
|
logger.debug("Connected to %s", self.host)
|
||||||
self.sendCommand(b"GET", b"/_matrix/key/v1/")
|
self.sendCommand(b"GET", self.path)
|
||||||
self.endHeaders()
|
self.endHeaders()
|
||||||
self.timer = reactor.callLater(
|
self.timer = reactor.callLater(
|
||||||
self.timeout,
|
self.timeout,
|
||||||
self.on_timeout
|
self.on_timeout
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def errback(self, error):
|
||||||
|
if not self.remote_key.called:
|
||||||
|
self.remote_key.errback(error)
|
||||||
|
|
||||||
|
def callback(self, result):
|
||||||
|
if not self.remote_key.called:
|
||||||
|
self.remote_key.callback(result)
|
||||||
|
|
||||||
def handleStatus(self, version, status, message):
|
def handleStatus(self, version, status, message):
|
||||||
if status != b"200":
|
if status != b"200":
|
||||||
# logger.info("Non-200 response from %s: %s %s",
|
# logger.info("Non-200 response from %s: %s %s",
|
||||||
# self.transport.getHost(), status, message)
|
# self.transport.getHost(), status, message)
|
||||||
|
error = SynapseKeyClientError("Non-200 response %r from %r" %
|
||||||
|
(status, self.host)
|
||||||
|
)
|
||||||
|
error.status = status
|
||||||
|
self.errback(error)
|
||||||
self.transport.abortConnection()
|
self.transport.abortConnection()
|
||||||
|
|
||||||
def handleResponse(self, response_body_bytes):
|
def handleResponse(self, response_body_bytes):
|
||||||
|
@ -89,15 +111,18 @@ class SynapseKeyClientProtocol(HTTPClient):
|
||||||
return
|
return
|
||||||
|
|
||||||
certificate = self.transport.getPeerCertificate()
|
certificate = self.transport.getPeerCertificate()
|
||||||
self.remote_key.callback((json_response, certificate))
|
self.callback((json_response, certificate))
|
||||||
self.transport.abortConnection()
|
self.transport.abortConnection()
|
||||||
self.timer.cancel()
|
self.timer.cancel()
|
||||||
|
|
||||||
def on_timeout(self):
|
def on_timeout(self):
|
||||||
logger.debug("Timeout waiting for response from %s", self.host)
|
logger.debug("Timeout waiting for response from %s", self.host)
|
||||||
self.remote_key.errback(IOError("Timeout waiting for response"))
|
self.errback(IOError("Timeout waiting for response"))
|
||||||
self.transport.abortConnection()
|
self.transport.abortConnection()
|
||||||
|
|
||||||
|
|
||||||
class SynapseKeyClientFactory(Factory):
|
class SynapseKeyClientFactory(Factory):
|
||||||
protocol = SynapseKeyClientProtocol
|
def protocol(self):
|
||||||
|
protocol = SynapseKeyClientProtocol()
|
||||||
|
protocol.path = self.path
|
||||||
|
return protocol
|
||||||
|
|
Loading…
Reference in a new issue