mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 14:03:54 +01:00
Verify signatures for server2server requests
This commit is contained in:
parent
10ef8e6e4b
commit
6684855767
5 changed files with 100 additions and 25 deletions
|
@ -22,6 +22,7 @@ from .transport import TransportLayer
|
|||
|
||||
def initialize_http_replication(homeserver):
|
||||
transport = TransportLayer(
|
||||
homeserver,
|
||||
homeserver.hostname,
|
||||
server=homeserver.get_resource_for_federation(),
|
||||
client=homeserver.get_http_client()
|
||||
|
|
|
@ -54,7 +54,7 @@ class TransportLayer(object):
|
|||
we receive data.
|
||||
"""
|
||||
|
||||
def __init__(self, server_name, server, client):
|
||||
def __init__(self, homeserver, server_name, server, client):
|
||||
"""
|
||||
Args:
|
||||
server_name (str): Local home server host
|
||||
|
@ -63,6 +63,7 @@ class TransportLayer(object):
|
|||
client (synapse.protocol.http.HttpClient): the http client used to
|
||||
send requests
|
||||
"""
|
||||
self.keyring = homeserver.get_keyring()
|
||||
self.server_name = server_name
|
||||
self.server = server
|
||||
self.client = client
|
||||
|
@ -195,6 +196,66 @@ class TransportLayer(object):
|
|||
|
||||
defer.returnValue(response)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _authenticate_request(self, request):
|
||||
json_request = {
|
||||
"method": request.method,
|
||||
"uri": request.uri,
|
||||
"destination": self.server_name,
|
||||
"signatures": {},
|
||||
}
|
||||
|
||||
content = None
|
||||
origin = None
|
||||
|
||||
if request.method == "PUT":
|
||||
#TODO: Handle other method types? other content types?
|
||||
content_bytes = request.content.read()
|
||||
content = json.loads(content_bytes)
|
||||
json_request["content"] = content
|
||||
|
||||
def parse_auth_header(header_str):
|
||||
params = auth.split(" ")[1].split(",")
|
||||
param_dict = dict(kv.split("=") for kv in params)
|
||||
def strip_quotes(value):
|
||||
if value.startswith("\""):
|
||||
return value[1:-1]
|
||||
else:
|
||||
return value
|
||||
origin = strip_quotes(param_dict["origin"])
|
||||
key = strip_quotes(param_dict["key"])
|
||||
sig = strip_quotes(param_dict["sig"])
|
||||
return (origin, key, sig)
|
||||
|
||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||
|
||||
if not auth_headers:
|
||||
#TODO(markjh): Send a 401 response?
|
||||
raise Exception("Missing auth headers")
|
||||
|
||||
for auth in auth_headers:
|
||||
if auth.startswith("X-Matrix"):
|
||||
(origin, key, sig) = parse_auth_header(auth)
|
||||
json_request["origin"] = origin
|
||||
json_request["signatures"].setdefault(origin,{})[key] = sig
|
||||
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
logger.debug("Checking %s %s",
|
||||
origin, encode_canonical_json(json_request))
|
||||
yield self.keyring.verify_json_for_server(origin, json_request)
|
||||
|
||||
defer.returnValue((origin, content))
|
||||
|
||||
def _with_authentication(self, handler):
|
||||
@defer.inlineCallbacks
|
||||
def new_handler(request, *args, **kwargs):
|
||||
(origin, content) = yield self._authenticate_request(request)
|
||||
response = yield handler(
|
||||
origin, content, request.args, *args, **kwargs
|
||||
)
|
||||
defer.returnValue(response)
|
||||
return new_handler
|
||||
|
||||
@log_function
|
||||
def register_received_handler(self, handler):
|
||||
""" Register a handler that will be fired when we receive data.
|
||||
|
@ -208,7 +269,7 @@ class TransportLayer(object):
|
|||
self.server.register_path(
|
||||
"PUT",
|
||||
re.compile("^" + PREFIX + "/send/([^/]*)/$"),
|
||||
self._on_send_request
|
||||
self._with_authentication(self._on_send_request)
|
||||
)
|
||||
|
||||
@log_function
|
||||
|
@ -226,9 +287,9 @@ class TransportLayer(object):
|
|||
self.server.register_path(
|
||||
"GET",
|
||||
re.compile("^" + PREFIX + "/pull/$"),
|
||||
lambda request: handler.on_pull_request(
|
||||
request.args["origin"][0],
|
||||
request.args["v"]
|
||||
self._with_authentication(
|
||||
lambda origin, content, query:
|
||||
handler.on_pull_request(query["origin"][0], query["v"])
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -237,8 +298,9 @@ class TransportLayer(object):
|
|||
self.server.register_path(
|
||||
"GET",
|
||||
re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
|
||||
lambda request, pdu_origin, pdu_id: handler.on_pdu_request(
|
||||
pdu_origin, pdu_id
|
||||
self._with_authentication(
|
||||
lambda origin, content, query, pdu_origin, pdu_id:
|
||||
handler.on_pdu_request(pdu_origin, pdu_id)
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -246,38 +308,47 @@ class TransportLayer(object):
|
|||
self.server.register_path(
|
||||
"GET",
|
||||
re.compile("^" + PREFIX + "/state/([^/]*)/$"),
|
||||
lambda request, context: handler.on_context_state_request(
|
||||
context
|
||||
self._with_authentication(
|
||||
lambda origin, content, query, context:
|
||||
handler.on_context_state_request(context)
|
||||
)
|
||||
)
|
||||
|
||||
self.server.register_path(
|
||||
"GET",
|
||||
re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
|
||||
lambda request, context: self._on_backfill_request(
|
||||
context, request.args["v"],
|
||||
request.args["limit"]
|
||||
self._with_authentication(
|
||||
lambda origin, content, query, context:
|
||||
self._on_backfill_request(
|
||||
context, query["v"], query["limit"]
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
self.server.register_path(
|
||||
"GET",
|
||||
re.compile("^" + PREFIX + "/context/([^/]*)/$"),
|
||||
lambda request, context: handler.on_context_pdus_request(context)
|
||||
self._with_authentication(
|
||||
lambda origin, content, query, context:
|
||||
handler.on_context_pdus_request(context)
|
||||
)
|
||||
)
|
||||
|
||||
# This is when we receive a server-server Query
|
||||
self.server.register_path(
|
||||
"GET",
|
||||
re.compile("^" + PREFIX + "/query/([^/]*)$"),
|
||||
lambda request, query_type: handler.on_query_request(
|
||||
query_type, {k: v[0] for k, v in request.args.items()}
|
||||
self._with_authentication(
|
||||
lambda origin, content, query, query_type:
|
||||
handler.on_query_request(
|
||||
query_type, {k: v[0] for k, v in query.items()}
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _on_send_request(self, request, transaction_id):
|
||||
def _on_send_request(self, origin, content, query, transaction_id):
|
||||
""" Called on PUT /send/<transaction_id>/
|
||||
|
||||
Args:
|
||||
|
@ -292,12 +363,7 @@ class TransportLayer(object):
|
|||
"""
|
||||
# Parse the request
|
||||
try:
|
||||
data = request.content.read()
|
||||
|
||||
l = data[:20].encode("string_escape")
|
||||
logger.debug("Got data: \"%s\"", l)
|
||||
|
||||
transaction_data = json.loads(data)
|
||||
transaction_data = content
|
||||
|
||||
logger.debug(
|
||||
"Decoded %s: %s",
|
||||
|
|
|
@ -177,16 +177,20 @@ class MatrixHttpClient(BaseHttpClient):
|
|||
|
||||
request = sign_json(request, self.server_name, self.signing_key)
|
||||
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
logger.debug("Signing " + " " * 11 + "%s %s",
|
||||
self.server_name, encode_canonical_json(request))
|
||||
|
||||
auth_headers = []
|
||||
|
||||
for key,sig in request["signatures"][self.server_name].items():
|
||||
auth_headers.append(
|
||||
auth_headers.append(bytes(
|
||||
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
|
||||
self.server_name, key, sig,
|
||||
)
|
||||
)
|
||||
))
|
||||
|
||||
headers_dict["Authorization"] = auth_headers
|
||||
headers_dict[b"Authorization"] = auth_headers
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def put_json(self, destination, path, data={}, json_data_callback=None):
|
||||
|
|
|
@ -221,6 +221,7 @@ class FederationTestCase(unittest.TestCase):
|
|||
json_data_callback=ANY,
|
||||
)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_recv_edu(self):
|
||||
recv_observer = Mock()
|
||||
|
|
|
@ -76,6 +76,9 @@ class MockHttpResource(HttpServer):
|
|||
mock_content.configure_mock(**config)
|
||||
mock_request.content = mock_content
|
||||
|
||||
mock_request.method = http_method
|
||||
mock_request.uri = path
|
||||
|
||||
# return the right path if the event requires it
|
||||
mock_request.path = path
|
||||
|
||||
|
|
Loading…
Reference in a new issue