0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 04:33:53 +01:00

Respond with more helpful error messages for unsigned requests

This commit is contained in:
Mark Haines 2014-10-13 16:39:15 +01:00
parent 25d80f35f1
commit 07639c79d9
6 changed files with 45 additions and 10 deletions

View file

@ -19,6 +19,7 @@ import logging
class Codes(object): class Codes(object):
UNAUTHORIZED = "M_UNAUTHORIZED"
FORBIDDEN = "M_FORBIDDEN" FORBIDDEN = "M_FORBIDDEN"
BAD_JSON = "M_BAD_JSON" BAD_JSON = "M_BAD_JSON"
NOT_JSON = "M_NOT_JSON" NOT_JSON = "M_NOT_JSON"

View file

@ -43,7 +43,7 @@ def fetch_server_key(server_name, ssl_context_factory):
return return
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
raise IOError("Cannot get key for " % server_name) raise IOError("Cannot get key for %s" % server_name)
class SynapseKeyClientError(Exception): class SynapseKeyClientError(Exception):
@ -93,7 +93,7 @@ class SynapseKeyClientProtocol(HTTPClient):
def on_timeout(self): def on_timeout(self):
logger.debug("Timeout waiting for response from %s", logger.debug("Timeout waiting for response from %s",
self.transport.getHost()) self.transport.getHost())
self.on_remote_key.errback(IOError("Timeout waiting for response")) self.remote_key.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection() self.transport.abortConnection()

View file

@ -20,6 +20,7 @@ from syutil.crypto.signing_key import (
is_signing_algorithm_supported, decode_verify_key_bytes is_signing_algorithm_supported, decode_verify_key_bytes
) )
from syutil.base64util import decode_base64, encode_base64 from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes
from OpenSSL import crypto from OpenSSL import crypto
@ -38,8 +39,36 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object): def verify_json_for_server(self, server_name, json_object):
key_ids = signature_ids(json_object, server_name) key_ids = signature_ids(json_object, server_name)
verify_key = yield self.get_server_verify_key(server_name, key_ids) if not key_ids:
verify_signed_json(json_object, server_name, verify_key) raise SynapseError(
400,
"No supported algorithms in signing keys",
Codes.UNAUTHORIZED,
)
try:
verify_key = yield self.get_server_verify_key(server_name, key_ids)
except IOError:
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except:
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED,
)
try:
verify_signed_json(json_object, server_name, verify_key)
except:
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key(self, server_name, key_ids): def get_server_verify_key(self, server_name, key_ids):

View file

@ -233,7 +233,7 @@ class TransportLayer(object):
return (origin, key, sig) return (origin, key, sig)
except: except:
raise SynapseError( raise SynapseError(
400, "Malformed Authorization Header", Codes.FORBIDDEN 400, "Malformed Authorization header", Codes.UNAUTHORIZED
) )
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
@ -246,7 +246,7 @@ class TransportLayer(object):
if not json_request["signatures"]: if not json_request["signatures"]:
raise SynapseError( raise SynapseError(
401, "Missing Authorization headers", Codes.FORBIDDEN, 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
) )
yield self.keyring.verify_json_for_server(origin, json_request) yield self.keyring.verify_json_for_server(origin, json_request)

View file

@ -121,7 +121,7 @@ class SQLBaseStore(object):
# "Simple" SQL API methods that operate on a single table with no JOINs, # "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns. # no complex WHERE clauses, just a dict of values for columns.
def _simple_insert(self, table, values, or_replace=False): def _simple_insert(self, table, values, or_replace=False, or_ignore=False):
"""Executes an INSERT query on the named table. """Executes an INSERT query on the named table.
Args: Args:
@ -130,13 +130,16 @@ class SQLBaseStore(object):
or_replace : bool; if True performs an INSERT OR REPLACE or_replace : bool; if True performs an INSERT OR REPLACE
""" """
return self.runInteraction( return self.runInteraction(
self._simple_insert_txn, table, values, or_replace=or_replace self._simple_insert_txn, table, values, or_replace=or_replace,
or_ignore=or_ignore,
) )
@log_function @log_function
def _simple_insert_txn(self, txn, table, values, or_replace=False): def _simple_insert_txn(self, txn, table, values, or_replace=False,
or_ignore=False):
sql = "%s INTO %s (%s) VALUES(%s)" % ( sql = "%s INTO %s (%s) VALUES(%s)" % (
("INSERT OR REPLACE" if or_replace else "INSERT"), ("INSERT OR REPLACE" if or_replace else
"INSERT OR IGNORE" if or_ignore else "INSERT"),
table, table,
", ".join(k for k in values), ", ".join(k for k in values),
", ".join("?" for k in values) ", ".join("?" for k in values)

View file

@ -65,6 +65,7 @@ class KeyStore(SQLBaseStore):
"ts_added_ms": time_now_ms, "ts_added_ms": time_now_ms,
"tls_certificate": buffer(tls_certificate_bytes), "tls_certificate": buffer(tls_certificate_bytes),
}, },
or_ignore=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -113,4 +114,5 @@ class KeyStore(SQLBaseStore):
"ts_added_ms": time_now_ms, "ts_added_ms": time_now_ms,
"verify_key": buffer(verify_key.encode()), "verify_key": buffer(verify_key.encode()),
}, },
or_ignore=True,
) )