mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-16 08:53:53 +01:00
Respond with more helpful error messages for unsigned requests
This commit is contained in:
parent
25d80f35f1
commit
07639c79d9
6 changed files with 45 additions and 10 deletions
|
@ -19,6 +19,7 @@ import logging
|
||||||
|
|
||||||
|
|
||||||
class Codes(object):
|
class Codes(object):
|
||||||
|
UNAUTHORIZED = "M_UNAUTHORIZED"
|
||||||
FORBIDDEN = "M_FORBIDDEN"
|
FORBIDDEN = "M_FORBIDDEN"
|
||||||
BAD_JSON = "M_BAD_JSON"
|
BAD_JSON = "M_BAD_JSON"
|
||||||
NOT_JSON = "M_NOT_JSON"
|
NOT_JSON = "M_NOT_JSON"
|
||||||
|
|
|
@ -43,7 +43,7 @@ def fetch_server_key(server_name, ssl_context_factory):
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
raise IOError("Cannot get key for " % server_name)
|
raise IOError("Cannot get key for %s" % server_name)
|
||||||
|
|
||||||
|
|
||||||
class SynapseKeyClientError(Exception):
|
class SynapseKeyClientError(Exception):
|
||||||
|
@ -93,7 +93,7 @@ class SynapseKeyClientProtocol(HTTPClient):
|
||||||
def on_timeout(self):
|
def on_timeout(self):
|
||||||
logger.debug("Timeout waiting for response from %s",
|
logger.debug("Timeout waiting for response from %s",
|
||||||
self.transport.getHost())
|
self.transport.getHost())
|
||||||
self.on_remote_key.errback(IOError("Timeout waiting for response"))
|
self.remote_key.errback(IOError("Timeout waiting for response"))
|
||||||
self.transport.abortConnection()
|
self.transport.abortConnection()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ from syutil.crypto.signing_key import (
|
||||||
is_signing_algorithm_supported, decode_verify_key_bytes
|
is_signing_algorithm_supported, decode_verify_key_bytes
|
||||||
)
|
)
|
||||||
from syutil.base64util import decode_base64, encode_base64
|
from syutil.base64util import decode_base64, encode_base64
|
||||||
|
from synapse.api.errors import SynapseError, Codes
|
||||||
|
|
||||||
from OpenSSL import crypto
|
from OpenSSL import crypto
|
||||||
|
|
||||||
|
@ -38,8 +39,36 @@ class Keyring(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def verify_json_for_server(self, server_name, json_object):
|
def verify_json_for_server(self, server_name, json_object):
|
||||||
key_ids = signature_ids(json_object, server_name)
|
key_ids = signature_ids(json_object, server_name)
|
||||||
|
if not key_ids:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"No supported algorithms in signing keys",
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
try:
|
||||||
verify_key = yield self.get_server_verify_key(server_name, key_ids)
|
verify_key = yield self.get_server_verify_key(server_name, key_ids)
|
||||||
|
except IOError:
|
||||||
|
raise SynapseError(
|
||||||
|
502,
|
||||||
|
"Error downloading keys for %s" % (server_name,),
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
raise SynapseError(
|
||||||
|
401,
|
||||||
|
"No key for %s with id %s" % (server_name, key_ids),
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
try:
|
||||||
verify_signed_json(json_object, server_name, verify_key)
|
verify_signed_json(json_object, server_name, verify_key)
|
||||||
|
except:
|
||||||
|
raise SynapseError(
|
||||||
|
401,
|
||||||
|
"Invalid signature for server %s with key %s:%s" % (
|
||||||
|
server_name, verify_key.alg, verify_key.version
|
||||||
|
),
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_server_verify_key(self, server_name, key_ids):
|
def get_server_verify_key(self, server_name, key_ids):
|
||||||
|
|
|
@ -233,7 +233,7 @@ class TransportLayer(object):
|
||||||
return (origin, key, sig)
|
return (origin, key, sig)
|
||||||
except:
|
except:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400, "Malformed Authorization Header", Codes.FORBIDDEN
|
400, "Malformed Authorization header", Codes.UNAUTHORIZED
|
||||||
)
|
)
|
||||||
|
|
||||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||||
|
@ -246,7 +246,7 @@ class TransportLayer(object):
|
||||||
|
|
||||||
if not json_request["signatures"]:
|
if not json_request["signatures"]:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
401, "Missing Authorization headers", Codes.FORBIDDEN,
|
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.keyring.verify_json_for_server(origin, json_request)
|
yield self.keyring.verify_json_for_server(origin, json_request)
|
||||||
|
|
|
@ -121,7 +121,7 @@ class SQLBaseStore(object):
|
||||||
# "Simple" SQL API methods that operate on a single table with no JOINs,
|
# "Simple" SQL API methods that operate on a single table with no JOINs,
|
||||||
# no complex WHERE clauses, just a dict of values for columns.
|
# no complex WHERE clauses, just a dict of values for columns.
|
||||||
|
|
||||||
def _simple_insert(self, table, values, or_replace=False):
|
def _simple_insert(self, table, values, or_replace=False, or_ignore=False):
|
||||||
"""Executes an INSERT query on the named table.
|
"""Executes an INSERT query on the named table.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -130,13 +130,16 @@ class SQLBaseStore(object):
|
||||||
or_replace : bool; if True performs an INSERT OR REPLACE
|
or_replace : bool; if True performs an INSERT OR REPLACE
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
self._simple_insert_txn, table, values, or_replace=or_replace
|
self._simple_insert_txn, table, values, or_replace=or_replace,
|
||||||
|
or_ignore=or_ignore,
|
||||||
)
|
)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def _simple_insert_txn(self, txn, table, values, or_replace=False):
|
def _simple_insert_txn(self, txn, table, values, or_replace=False,
|
||||||
|
or_ignore=False):
|
||||||
sql = "%s INTO %s (%s) VALUES(%s)" % (
|
sql = "%s INTO %s (%s) VALUES(%s)" % (
|
||||||
("INSERT OR REPLACE" if or_replace else "INSERT"),
|
("INSERT OR REPLACE" if or_replace else
|
||||||
|
"INSERT OR IGNORE" if or_ignore else "INSERT"),
|
||||||
table,
|
table,
|
||||||
", ".join(k for k in values),
|
", ".join(k for k in values),
|
||||||
", ".join("?" for k in values)
|
", ".join("?" for k in values)
|
||||||
|
|
|
@ -65,6 +65,7 @@ class KeyStore(SQLBaseStore):
|
||||||
"ts_added_ms": time_now_ms,
|
"ts_added_ms": time_now_ms,
|
||||||
"tls_certificate": buffer(tls_certificate_bytes),
|
"tls_certificate": buffer(tls_certificate_bytes),
|
||||||
},
|
},
|
||||||
|
or_ignore=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -113,4 +114,5 @@ class KeyStore(SQLBaseStore):
|
||||||
"ts_added_ms": time_now_ms,
|
"ts_added_ms": time_now_ms,
|
||||||
"verify_key": buffer(verify_key.encode()),
|
"verify_key": buffer(verify_key.encode()),
|
||||||
},
|
},
|
||||||
|
or_ignore=True,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue