
Merge pull request #129 from matrix-org/key_distribution

Key distribution v2
Commit 64991b0c8b by Erik Johnston, 2015-04-29 13:34:38 +01:00
15 changed files with 967 additions and 77 deletions
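In outline, this change adds a v2 key API alongside the existing /_matrix/key/v1 endpoint: GET /_matrix/key/v2/server/<key_id> serves this homeserver's own signing keys (LocalKey, below), and POST /_matrix/key/v2/query lets a server act as a trusted "perspective" that fetches, verifies and caches other servers' keys (RemoteKey, below). A rough sketch of how a client might exercise the two endpoints; the hostnames, key ids and the use of the requests library are illustrative assumptions, not part of the commit:

import requests  # assumed for illustration; any HTTP client would do

BASE = "https://this.server.example.com:8448"  # placeholder homeserver

# Fetch this server's own signing keys (handled by LocalKey).
own = requests.get(BASE + "/_matrix/key/v2/server/ed25519:auto").json()
print own["server_name"], own["valid_until_ts"], own["verify_keys"].keys()

# Ask the server, acting as a perspective, for another server's keys
# (handled by RemoteKey); the body shape matches the RemoteKey docstring.
query = {
    "server_keys": {
        "remote.server.example.com": {
            "ed25519:abc123": {"minimum_valid_until_ts": 0}
        }
    }
}
resp = requests.post(BASE + "/_matrix/key/v2/query", json=query).json()
for entry in resp["server_keys"]:
    print entry["server_name"], entry["verify_keys"].keys()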

synapse/api/urls.py

@@ -22,5 +22,6 @@ STATIC_PREFIX = "/_matrix/static"
WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"

synapse/app/homeserver.py

@@ -36,10 +36,12 @@ from synapse.http.server import JsonResource, RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import (
    CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
    SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, STATIC_PREFIX,
    SERVER_KEY_V2_PREFIX,
)
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
@@ -97,6 +99,9 @@ class SynapseHomeServer(HomeServer):
    def build_resource_for_server_key(self):
        return LocalKey(self)

    def build_resource_for_server_key_v2(self):
        return KeyApiV2Resource(self)

    def build_resource_for_metrics(self):
        if self.get_config().enable_metrics:
            return MetricsResource(self)
@@ -134,6 +139,7 @@ class SynapseHomeServer(HomeServer):
            (FEDERATION_PREFIX, self.get_resource_for_federation()),
            (CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
            (SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
            (SERVER_KEY_V2_PREFIX, self.get_resource_for_server_key_v2()),
            (MEDIA_PREFIX, self.get_resource_for_media_repository()),
            (STATIC_PREFIX, self.get_resource_for_static_content()),
        ]

synapse/config/_base.py

@@ -77,6 +77,17 @@ class Config(object):
        with open(file_path) as file_stream:
            return file_stream.read()

    @classmethod
    def read_yaml_file(cls, file_path, config_name):
        cls.check_file(file_path, config_name)
        with open(file_path) as file_stream:
            try:
                return yaml.load(file_stream)
            except Exception:
                raise ConfigError(
                    "Error parsing yaml in file %r" % (file_path,)
                )

    @staticmethod
    def default_path(name):
        return os.path.abspath(os.path.join(os.path.curdir, name))
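A brief usage sketch of the new helper: it checks that the file exists, parses it as YAML, and converts parse failures into ConfigError. The subclass and argument name here are hypothetical, not part of the commit:

class MyConfig(Config):  # hypothetical subclass for illustration
    def __init__(self, args):
        super(MyConfig, self).__init__(args)
        # Returns the parsed YAML (e.g. a dict), or raises ConfigError.
        self.trusted_servers = self.read_yaml_file(
            args.trusted_servers_path, "trusted_servers_path"
        )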

synapse/config/homeserver.py

@@ -24,12 +24,13 @@ from .voip import VoipConfig
from .registration import RegistrationConfig
from .metrics import MetricsConfig
from .appservice import AppServiceConfig
from .key import KeyConfig


class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
                       RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
                       VoipConfig, RegistrationConfig,
                       MetricsConfig, AppServiceConfig, KeyConfig,):
    pass

synapse/config/key.py (new file, 147 lines)

@@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from ._base import Config, ConfigError
import syutil.crypto.signing_key
from syutil.crypto.signing_key import (
is_signing_algorithm_supported, decode_verify_key_bytes
)
from syutil.base64util import decode_base64
class KeyConfig(Config):
def __init__(self, args):
super(KeyConfig, self).__init__(args)
self.signing_key = self.read_signing_key(args.signing_key_path)
self.old_signing_keys = self.read_old_signing_keys(
args.old_signing_key_path
)
self.key_refresh_interval = args.key_refresh_interval
self.perspectives = self.read_perspectives(
args.perspectives_config_path
)
@classmethod
def add_arguments(cls, parser):
super(KeyConfig, cls).add_arguments(parser)
key_group = parser.add_argument_group("keys")
key_group.add_argument("--signing-key-path",
help="The signing key to sign messages with")
key_group.add_argument("--old-signing-key-path",
help="The keys that the server used to sign"
" sign messages with but won't use"
" to sign new messages. E.g. it has"
" lost its private key")
key_group.add_argument("--key-refresh-interval",
default=24 * 60 * 60 * 1000, # 1 Day
help="How long a key response is valid for."
" Used to set the exipiry in /key/v2/."
" Controls how frequently servers will"
" query what keys are still valid")
key_group.add_argument("--perspectives-config-path",
help="The trusted servers to download signing"
" keys from")
def read_perspectives(self, perspectives_config_path):
config = self.read_yaml_file(
perspectives_config_path, "perspectives_config_path"
)
servers = {}
for server_name, server_config in config["servers"].items():
for key_id, key_data in server_config["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
servers.setdefault(server_name, {})[key_id] = verify_key
return servers
def read_signing_key(self, signing_key_path):
signing_keys = self.read_file(signing_key_path, "signing_key")
try:
return syutil.crypto.signing_key.read_signing_keys(
signing_keys.splitlines(True)
)
except Exception:
raise ConfigError(
"Error reading signing_key."
" Try running again with --generate-config"
)
def read_old_signing_keys(self, old_signing_key_path):
old_signing_keys = self.read_file(
old_signing_key_path, "old_signing_key"
)
try:
return syutil.crypto.signing_key.read_old_signing_keys(
old_signing_keys.splitlines(True)
)
except Exception:
raise ConfigError(
"Error reading old signing keys."
)
@classmethod
def generate_config(cls, args, config_dir_path):
super(KeyConfig, cls).generate_config(args, config_dir_path)
base_key_name = os.path.join(config_dir_path, args.server_name)
args.pid_file = os.path.abspath(args.pid_file)
if not args.signing_key_path:
args.signing_key_path = base_key_name + ".signing.key"
if not os.path.exists(args.signing_key_path):
with open(args.signing_key_path, "w") as signing_key_file:
syutil.crypto.signing_key.write_signing_keys(
signing_key_file,
(syutil.crypto.signing_key.generate_signing_key("auto"),),
)
else:
signing_keys = cls.read_file(args.signing_key_path, "signing_key")
if len(signing_keys.split("\n")[0].split()) == 1:
# handle keys in the old format.
key = syutil.crypto.signing_key.decode_signing_key_base64(
syutil.crypto.signing_key.NACL_ED25519,
"auto",
signing_keys.split("\n")[0]
)
with open(args.signing_key_path, "w") as signing_key_file:
syutil.crypto.signing_key.write_signing_keys(
signing_key_file,
(key,),
)
if not args.old_signing_key_path:
args.old_signing_key_path = base_key_name + ".old.signing.keys"
if not os.path.exists(args.old_signing_key_path):
with open(args.old_signing_key_path, "w"):
pass
if not args.perspectives_config_path:
args.perspectives_config_path = base_key_name + ".perspectives"
if not os.path.exists(args.perspectives_config_path):
with open(args.perspectives_config_path, "w") as perspectives_file:
perspectives_file.write(
'servers:\n'
' matrix.org:\n'
' verify_keys:\n'
' "ed25519:auto":\n'
' key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"\n'
)

synapse/config/server.py

@@ -13,16 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ._base import Config
from ._base import Config, ConfigError
import syutil.crypto.signing_key


class ServerConfig(Config):

    def __init__(self, args):
        super(ServerConfig, self).__init__(args)
        self.server_name = args.server_name
        self.signing_key = self.read_signing_key(args.signing_key_path)
        self.bind_port = args.bind_port
        self.bind_host = args.bind_host
        self.unsecure_port = args.unsecure_port
@@ -53,8 +50,6 @@ class ServerConfig(Config):
            "This is used by remote servers to connect to this server, "
            "e.g. matrix.org, localhost:8080, etc."
        )
        server_group.add_argument("--signing-key-path",
                                  help="The signing key to sign messages with")
        server_group.add_argument("-p", "--bind-port", metavar="PORT",
                                  type=int, help="https port to listen on",
                                  default=8448)
@@ -83,46 +78,3 @@ class ServerConfig(Config):
                                  "Zero is used to indicate synapse "
                                  "should set the soft limit to the hard"
                                  "limit.")
def read_signing_key(self, signing_key_path):
signing_keys = self.read_file(signing_key_path, "signing_key")
try:
return syutil.crypto.signing_key.read_signing_keys(
signing_keys.splitlines(True)
)
except Exception:
raise ConfigError(
"Error reading signing_key."
" Try running again with --generate-config"
)
@classmethod
def generate_config(cls, args, config_dir_path):
super(ServerConfig, cls).generate_config(args, config_dir_path)
base_key_name = os.path.join(config_dir_path, args.server_name)
args.pid_file = os.path.abspath(args.pid_file)
if not args.signing_key_path:
args.signing_key_path = base_key_name + ".signing.key"
if not os.path.exists(args.signing_key_path):
with open(args.signing_key_path, "w") as signing_key_file:
syutil.crypto.signing_key.write_signing_keys(
signing_key_file,
(syutil.crypto.signing_key.generate_signing_key("auto"),),
)
else:
signing_keys = cls.read_file(args.signing_key_path, "signing_key")
if len(signing_keys.split("\n")[0].split()) == 1:
# handle keys in the old format.
key = syutil.crypto.signing_key.decode_signing_key_base64(
syutil.crypto.signing_key.NACL_ED25519,
"auto",
signing_keys.split("\n")[0]
)
with open(args.signing_key_path, "w") as signing_key_file:
syutil.crypto.signing_key.write_signing_keys(
signing_key_file,
(key,),
)

synapse/crypto/keyclient.py

@@ -25,12 +25,15 @@ import logging
logger = logging.getLogger(__name__)

KEY_API_V1 = b"/_matrix/key/v1/"


@defer.inlineCallbacks
def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
    """Fetch the keys for a remote server."""

    factory = SynapseKeyClientFactory()
    factory.path = path

    endpoint = matrix_federation_endpoint(
        reactor, server_name, ssl_context_factory, timeout=30
    )
@@ -42,13 +45,19 @@ def fetch_server_key(server_name, ssl_context_factory):
            server_response, server_certificate = yield protocol.remote_key
            defer.returnValue((server_response, server_certificate))
            return
        except SynapseKeyClientError as e:
            logger.exception("Error getting key for %r" % (server_name,))
            if e.status.startswith("4"):
                # Don't retry for 4xx responses.
                raise IOError("Cannot get key for %r" % server_name)
        except Exception as e:
            logger.exception(e)
    raise IOError("Cannot get key for %r" % server_name)


class SynapseKeyClientError(Exception):
    """The key wasn't retrieved from the remote server."""
    status = None
    pass
@@ -66,17 +75,30 @@ class SynapseKeyClientProtocol(HTTPClient):

    def connectionMade(self):
        self.host = self.transport.getHost()
        logger.debug("Connected to %s", self.host)
        self.sendCommand(b"GET", self.path)
        self.endHeaders()
        self.timer = reactor.callLater(
            self.timeout,
            self.on_timeout
        )

    def errback(self, error):
        if not self.remote_key.called:
            self.remote_key.errback(error)

    def callback(self, result):
        if not self.remote_key.called:
            self.remote_key.callback(result)

    def handleStatus(self, version, status, message):
        if status != b"200":
            # logger.info("Non-200 response from %s: %s %s",
            #             self.transport.getHost(), status, message)
            error = SynapseKeyClientError(
                "Non-200 response %r from %r" % (status, self.host)
            )
            error.status = status
            self.errback(error)
            self.transport.abortConnection()

    def handleResponse(self, response_body_bytes):
@@ -89,15 +111,18 @@ class SynapseKeyClientProtocol(HTTPClient):
            return

        certificate = self.transport.getPeerCertificate()
        self.callback((json_response, certificate))
        self.transport.abortConnection()
        self.timer.cancel()

    def on_timeout(self):
        logger.debug("Timeout waiting for response from %s", self.host)
        self.errback(IOError("Timeout waiting for response"))
        self.transport.abortConnection()


class SynapseKeyClientFactory(Factory):
    def protocol(self):
        protocol = SynapseKeyClientProtocol()
        protocol.path = self.path
        return protocol
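fetch_server_key now takes the request path as an argument, defaulting to the v1 endpoint, and threads it through the factory to the protocol. A sketch of the two call styles, mirroring how keyring.py uses it below; the server name and key id are placeholders, and the calls must run inside a @defer.inlineCallbacks function with an ssl_context_factory in scope:

import urllib

# v1, using the unchanged default path:
response, tls_certificate = yield fetch_server_key(
    "remote.server.example.com", ssl_context_factory
)

# v2, asking for one specific key id:
requested_key_id = "ed25519:auto"  # placeholder
response, tls_certificate = yield fetch_server_key(
    "remote.server.example.com", ssl_context_factory,
    path=(b"/_matrix/key/v2/server/%s" % (
        urllib.quote(requested_key_id),
    )).encode("ascii"),
)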

synapse/crypto/keyring.py

@@ -15,7 +15,9 @@
from synapse.crypto.keyclient import fetch_server_key
from twisted.internet import defer
from syutil.crypto.jsonsign import (
    verify_signed_json, signature_ids, sign_json, encode_canonical_json
)
from syutil.crypto.signing_key import (
    is_signing_algorithm_supported, decode_verify_key_bytes
)
@@ -28,6 +30,8 @@ from synapse.util.async import create_observer
from OpenSSL import crypto

import urllib
import hashlib
import logging
@@ -38,6 +42,9 @@ class Keyring(object):
    def __init__(self, hs):
        self.store = hs.get_datastore()
        self.clock = hs.get_clock()
        self.client = hs.get_http_client()
        self.config = hs.get_config()
        self.perspective_servers = self.config.perspectives
        self.hs = hs

        self.key_downloads = {}
@@ -89,12 +96,11 @@ class Keyring(object):
    @defer.inlineCallbacks
    def get_server_verify_key(self, server_name, key_ids):
        """Finds a verification key for the server with one of the key ids.
        Tries to fetch the key from a trusted perspective server first.

        Args:
            server_name (str): The name of the server to fetch a key for.
            key_ids (list of str): The key_ids to check for.
        """
        # Check the datastore to see if we have one cached.
        cached = yield self.store.get_server_verify_keys(server_name, key_ids)

        if cached:
@@ -117,7 +123,29 @@
    @defer.inlineCallbacks
    def _get_server_verify_key_impl(self, server_name, key_ids):
        keys = None

        perspective_results = []
        for perspective_name, perspective_keys in self.perspective_servers.items():
            @defer.inlineCallbacks
            def get_key():
                try:
                    result = yield self.get_server_verify_key_v2_indirect(
                        server_name, key_ids, perspective_name, perspective_keys
                    )
                    defer.returnValue(result)
                except Exception:
                    logging.info(
                        "Unable to get key %r for %r from %r",
                        key_ids, server_name, perspective_name,
                    )
            perspective_results.append(get_key())

        perspective_results = yield defer.gatherResults(perspective_results)

        for results in perspective_results:
            if results is not None:
                keys = results

        limiter = yield get_retry_limiter(
            server_name,
@@ -126,10 +154,234 @@
        )

        with limiter:
            if keys is None:
                try:
                    keys = yield self.get_server_verify_key_v2_direct(
                        server_name, key_ids
                    )
                except Exception:
                    pass

            if keys is None:
                keys = yield self.get_server_verify_key_v1_direct(
                    server_name, key_ids
                )

        for key_id in key_ids:
            if key_id in keys:
                defer.returnValue(keys[key_id])
                return
        raise ValueError("No verification key found for given key ids")
@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(self, server_name, key_ids,
perspective_name,
perspective_keys):
limiter = yield get_retry_limiter(
perspective_name, self.clock, self.store
)
with limiter:
# TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating
# an incoming request.
responses = yield self.client.post_json(
destination=perspective_name,
path=b"/_matrix/key/v2/query",
data={
u"server_keys": {
server_name: {
key_id: {
u"minimum_valid_until_ts": 0
} for key_id in key_ids
}
}
},
)
keys = {}
for response in responses:
if (u"signatures" not in response
or perspective_name not in response[u"signatures"]):
raise ValueError(
"Key response not signed by perspective server"
" %r" % (perspective_name,)
)
verified = False
for key_id in response[u"signatures"][perspective_name]:
if key_id in perspective_keys:
verify_signed_json(
response,
perspective_name,
perspective_keys[key_id]
)
verified = True
if not verified:
logging.info(
"Response from perspective server %r not signed with a"
" known key, signed with: %r, known keys: %r",
perspective_name,
list(response[u"signatures"][perspective_name]),
list(perspective_keys)
)
raise ValueError(
"Response not signed with a known key for perspective"
" server %r" % (perspective_name,)
)
response_keys = yield self.process_v2_response(
server_name, perspective_name, response
)
keys.update(response_keys)
yield self.store_keys(
server_name=server_name,
from_server=perspective_name,
verify_keys=keys,
)
defer.returnValue(keys)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {}
for requested_key_id in key_ids:
if requested_key_id in keys:
continue
(response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_context_factory,
path=(b"/_matrix/key/v2/server/%s" % (
urllib.quote(requested_key_id),
)).encode("ascii"),
)
if (u"signatures" not in response
or server_name not in response[u"signatures"]):
raise ValueError("Key response not signed by remote server")
if "tls_fingerprints" not in response:
raise ValueError("Key response missing TLS fingerprints")
certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate
)
sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
response_sha256_fingerprints = set()
for fingerprint in response[u"tls_fingerprints"]:
if u"sha256" in fingerprint:
response_sha256_fingerprints.add(fingerprint[u"sha256"])
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
server_name=server_name,
from_server=server_name,
requested_id=requested_key_id,
response_json=response,
)
keys.update(response_keys)
yield self.store_keys(
server_name=server_name,
from_server=server_name,
verify_keys=keys,
)
defer.returnValue(keys)
@defer.inlineCallbacks
def process_v2_response(self, server_name, from_server, response_json,
requested_id=None):
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
for key_id, key_data in response_json["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.time_added = time_now_ms
verify_keys[key_id] = verify_key
old_verify_keys = {}
for key_id, key_data in response_json["old_verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.expired = key_data["expired_ts"]
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key
for key_id in response_json["signatures"][server_name]:
if key_id not in response_json["verify_keys"]:
raise ValueError(
"Key response must include verification keys for all"
" signatures"
)
if key_id in verify_keys:
verify_signed_json(
response_json,
server_name,
verify_keys[key_id]
)
signed_key_json = sign_json(
response_json,
self.config.server_name,
self.config.signing_key[0],
)
signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
updated_key_ids = set()
if requested_id is not None:
updated_key_ids.add(requested_id)
updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys)
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
for key_id in updated_key_ids:
yield self.store.store_server_keys_json(
server_name=server_name,
key_id=key_id,
from_server=server_name,
ts_now_ms=time_now_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
defer.returnValue(response_keys)
raise ValueError("No verification key found for given key ids")
@defer.inlineCallbacks
def get_server_verify_key_v1_direct(self, server_name, key_ids):
"""Finds a verification key for the server with one of the key ids.
Args:
server_name (str): The name of the server to fetch a key for.
            key_ids (list of str): The key_ids to check for.
"""
# Try to fetch the key from the remote server.
(response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_context_factory
)
        # Check the response.

        x509_certificate_bytes = crypto.dump_certificate(
@@ -148,11 +400,16 @@ class Keyring(object):
        if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
            raise ValueError("TLS certificate doesn't match")

        # Cache the result in the datastore.
        time_now_ms = self.clock.time_msec()

        verify_keys = {}
        for key_id, key_base64 in response["verify_keys"].items():
            if is_signing_algorithm_supported(key_id):
                key_bytes = decode_base64(key_base64)
                verify_key = decode_verify_key_bytes(key_id, key_bytes)
                verify_key.time_added = time_now_ms
                verify_keys[key_id] = verify_key

        for key_id in response["signatures"][server_name]:
@@ -168,10 +425,6 @@ class Keyring(object):
                verify_keys[key_id]
            )

        # Cache the result in the datastore.
        time_now_ms = self.clock.time_msec()

        yield self.store.store_server_certificate(
            server_name,
            server_name,
@@ -179,14 +432,26 @@ class Keyring(object):
            tls_certificate,
        )

        yield self.store_keys(
            server_name=server_name,
            from_server=server_name,
            verify_keys=verify_keys,
        )

        defer.returnValue(verify_keys)

    @defer.inlineCallbacks
    def store_keys(self, server_name, from_server, verify_keys):
        """Store a collection of verify keys for a given server.

        Args:
            server_name(str): The name of the server the keys are for.
            from_server(str): The server the keys were downloaded from.
            verify_keys(dict): A mapping of key_id to VerifyKey.
        Returns:
            A deferred that completes when the keys are stored.
        """
        for key_id, key in verify_keys.items():
            # TODO(markjh): Store whether the keys have expired.
            yield self.store.store_server_verify_key(
                server_name, server_name, key.time_added, key
            )
        for key_id in key_ids:
            if key_id in verify_keys:
                defer.returnValue(verify_keys[key_id])
                return
        raise ValueError("No verification key found for given key ids")
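The heart of the transport check in get_server_verify_key_v2_direct is comparing a base64-encoded SHA-256 digest of the DER certificate actually presented on the connection against the sha256 fingerprints advertised in the key response. Reduced to a standalone helper (the function name is ours, not the commit's):

import hashlib
from OpenSSL import crypto
from syutil.base64util import encode_base64

def certificate_matches_response(tls_certificate, response):
    # DER-encode the certificate seen on the wire and fingerprint it.
    certificate_bytes = crypto.dump_certificate(crypto.FILETYPE_ASN1, tls_certificate)
    sha256_fingerprint_b64 = encode_base64(hashlib.sha256(certificate_bytes).digest())

    # The response must advertise at least one matching sha256 fingerprint.
    advertised = set(
        fingerprint[u"sha256"]
        for fingerprint in response.get(u"tls_fingerprints", [])
        if u"sha256" in fingerprint
    )
    return sha256_fingerprint_b64 in advertised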

synapse/rest/key/v2/__init__.py (new file)

@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.web.resource import Resource
from .local_key_resource import LocalKey
from .remote_key_resource import RemoteKey
class KeyApiV2Resource(Resource):
def __init__(self, hs):
Resource.__init__(self)
self.putChild("server", LocalKey(hs))
self.putChild("query", RemoteKey(hs))

synapse/rest/key/v2/local_key_resource.py (new file)

@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.web.resource import Resource
from synapse.http.server import respond_with_json_bytes
from syutil.crypto.jsonsign import sign_json
from syutil.base64util import encode_base64
from syutil.jsonutil import encode_canonical_json
from hashlib import sha256
from OpenSSL import crypto
import logging
logger = logging.getLogger(__name__)
class LocalKey(Resource):
"""HTTP resource containing encoding the TLS X.509 certificate and NACL
signature verification keys for this server::
GET /_matrix/key/v2/server/a.key.id HTTP/1.1
HTTP/1.1 200 OK
Content-Type: application/json
{
"valid_until_ts": # integer posix timestamp when this result expires.
"server_name": "this.server.example.com"
"verify_keys": {
"algorithm:version": {
"key": # base64 encoded NACL verification key.
}
},
"old_verify_keys": {
"algorithm:version": {
"expired_ts": # integer posix timestamp when the key expired.
"key": # base64 encoded NACL verification key.
}
}
"tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert.
"signatures": {
"this.server.example.com": {
"algorithm:version": # NACL signature for this server
}
}
}
"""
isLeaf = True
def __init__(self, hs):
self.version_string = hs.version_string
self.config = hs.config
self.clock = hs.clock
self.update_response_body(self.clock.time_msec())
Resource.__init__(self)
def update_response_body(self, time_now_msec):
refresh_interval = self.config.key_refresh_interval
self.valid_until_ts = int(time_now_msec + refresh_interval)
self.response_body = encode_canonical_json(self.response_json_object())
def response_json_object(self):
verify_keys = {}
for key in self.config.signing_key:
verify_key_bytes = key.verify_key.encode()
key_id = "%s:%s" % (key.alg, key.version)
verify_keys[key_id] = {
u"key": encode_base64(verify_key_bytes)
}
old_verify_keys = {}
for key in self.config.old_signing_keys:
key_id = "%s:%s" % (key.alg, key.version)
verify_key_bytes = key.encode()
old_verify_keys[key_id] = {
u"key": encode_base64(verify_key_bytes),
u"expired_ts": key.expired,
}
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1,
self.config.tls_certificate
)
sha256_fingerprint = sha256(x509_certificate_bytes).digest()
json_object = {
u"valid_until_ts": self.valid_until_ts,
u"server_name": self.config.server_name,
u"verify_keys": verify_keys,
u"old_verify_keys": old_verify_keys,
u"tls_fingerprints": [{
u"sha256": encode_base64(sha256_fingerprint),
}]
}
for key in self.config.signing_key:
json_object = sign_json(
json_object,
self.config.server_name,
key,
)
return json_object
def render_GET(self, request):
time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains.
if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:
self.update_response_body(time_now)
return respond_with_json_bytes(
request, 200, self.response_body,
version_string=self.version_string
)
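render_GET reuses the cached, pre-signed body until less than half of key_refresh_interval remains before valid_until_ts, at which point update_response_body regenerates and re-signs it with a new expiry. With the default 24-hour interval the check works out as follows (timestamps are made up):

key_refresh_interval = 24 * 60 * 60 * 1000        # default: one day, in ms
generated_at = 1430000000000                      # arbitrary example timestamp
valid_until_ts = generated_at + key_refresh_interval

# 11 hours in: more than half the interval remains, keep the cached body.
now = generated_at + 11 * 60 * 60 * 1000
assert not (now + key_refresh_interval / 2 > valid_until_ts)

# 13 hours in: less than half remains, so the response is rebuilt and re-signed.
now = generated_at + 13 * 60 * 60 * 1000
assert now + key_refresh_interval / 2 > valid_until_ts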

synapse/rest/key/v2/remote_key_resource.py (new file)

@@ -0,0 +1,242 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import request_handler, respond_with_json_bytes
from synapse.http.servlet import parse_integer
from synapse.api.errors import SynapseError, Codes
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
from io import BytesIO
import json
import logging
logger = logging.getLogger(__name__)
class RemoteKey(Resource):
"""HTTP resource for retreiving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
that the NACL signature for the remote server is valid. Returns a dict of
JSON signed by both the remote server and by this server.
Supports individual GET APIs and a bulk query POST API.
Requsts:
GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1
GET /_matrix/key/v2/query/remote.server.example.com/a.key.id HTTP/1.1
POST /_matrix/v2/query HTTP/1.1
Content-Type: application/json
{
"server_keys": {
"remote.server.example.com": {
"a.key.id": {
"minimum_valid_until_ts": 1234567890123
}
}
}
}
Response:
HTTP/1.1 200 OK
Content-Type: application/json
{
"server_keys": [
{
"server_name": "remote.server.example.com"
"valid_until_ts": # posix timestamp
"verify_keys": {
"a.key.id": { # The identifier for a key.
key: "" # base64 encoded verification key.
}
}
"old_verify_keys": {
"an.old.key.id": { # The identifier for an old key.
key: "", # base64 encoded key
"expired_ts": 0, # when the key stop being used.
}
}
"tls_fingerprints": [
{ "sha256": # fingerprint }
]
"signatures": {
"remote.server.example.com": {...}
"this.server.example.com": {...}
}
}
]
}
"""
isLeaf = True
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.store = hs.get_datastore()
self.version_string = hs.version_string
self.clock = hs.get_clock()
def render_GET(self, request):
self.async_render_GET(request)
return NOT_DONE_YET
@request_handler
@defer.inlineCallbacks
def async_render_GET(self, request):
if len(request.postpath) == 1:
server, = request.postpath
query = {server: {}}
elif len(request.postpath) == 2:
server, key_id = request.postpath
minimum_valid_until_ts = parse_integer(
request, "minimum_valid_until_ts"
)
arguments = {}
if minimum_valid_until_ts is not None:
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
query = {server: {key_id: arguments}}
else:
raise SynapseError(
404, "Not found %r" % request.postpath, Codes.NOT_FOUND
)
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
def render_POST(self, request):
self.async_render_POST(request)
return NOT_DONE_YET
@request_handler
@defer.inlineCallbacks
def async_render_POST(self, request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise ValueError()
except ValueError:
raise SynapseError(
400, "Content must be JSON object.", errcode=Codes.NOT_JSON
)
query = content["server_keys"]
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
@defer.inlineCallbacks
def query_keys(self, request, query, query_remote_on_cache_miss=False):
logger.info("Handling query for keys %r", query)
store_queries = []
for server_name, key_ids in query.items():
if not key_ids:
key_ids = (None,)
for key_id in key_ids:
store_queries.append((server_name, key_id, None))
cached = yield self.store.get_server_keys_json(store_queries)
json_results = set()
time_now_ms = self.clock.time_msec()
cache_misses = dict()
for (server_name, key_id, from_server), results in cached.items():
results = [
(result["ts_added_ms"], result) for result in results
]
if not results and key_id is not None:
cache_misses.setdefault(server_name, set()).add(key_id)
continue
if key_id is not None:
ts_added_ms, most_recent_result = max(results)
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
req_key = query.get(server_name, {}).get(key_id, {})
req_valid_until = req_key.get("minimum_valid_until_ts")
miss = False
if req_valid_until is not None:
if ts_valid_until_ms < req_valid_until:
logger.debug(
"Cached response for %r/%r is older than requested"
": valid_until (%r) < minimum_valid_until (%r)",
server_name, key_id,
ts_valid_until_ms, req_valid_until
)
miss = True
else:
logger.debug(
"Cached response for %r/%r is newer than requested"
": valid_until (%r) >= minimum_valid_until (%r)",
server_name, key_id,
ts_valid_until_ms, req_valid_until
)
elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms:
logger.debug(
"Cached response for %r/%r is too old"
": (added (%r) + valid_until (%r)) / 2 < now (%r)",
server_name, key_id,
ts_added_ms, ts_valid_until_ms, time_now_ms
)
                    # We are more than half way through the lifetime of the
                    # response. We should fetch a fresh copy.
miss = True
else:
logger.debug(
"Cached response for %r/%r is still valid"
": (added (%r) + valid_until (%r)) / 2 < now (%r)",
server_name, key_id,
ts_added_ms, ts_valid_until_ms, time_now_ms
)
if miss:
cache_misses.setdefault(server_name, set()).add(key_id)
json_results.add(bytes(most_recent_result["key_json"]))
else:
for ts_added, result in results:
json_results.add(bytes(result["key_json"]))
if cache_misses and query_remote_on_cache_miss:
for server_name, key_ids in cache_misses.items():
try:
yield self.keyring.get_server_verify_key_v2_direct(
server_name, key_ids
)
                except Exception:
                    logger.exception("Failed to get key for %r", server_name)
yield self.query_keys(
request, query, query_remote_on_cache_miss=False
)
else:
result_io = BytesIO()
result_io.write(b"{\"server_keys\":")
sep = b"["
for json_bytes in json_results:
result_io.write(sep)
result_io.write(json_bytes)
sep = b","
if sep == b"[":
result_io.write(sep)
result_io.write(b"]}")
respond_with_json_bytes(
request, 200, result_io.getvalue(),
version_string=self.version_string
)
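When the request gives no minimum_valid_until_ts, a cached entry counts as a miss once the current time passes the midpoint between when the JSON was fetched and when it expires, which is what the (ts_added_ms + ts_valid_until_ms) / 2 comparison above computes. A small numeric illustration (values are made up):

ts_added_ms = 1430000000000                    # when the key JSON was cached
ts_valid_until_ms = ts_added_ms + 86400000     # origin says valid for 24 hours

midpoint = (ts_added_ms + ts_valid_until_ms) / 2   # 12 hours after caching

now = ts_added_ms + 6 * 3600 * 1000    # 6 hours in: cached copy still served
assert not (midpoint < now)

now = ts_added_ms + 18 * 3600 * 1000   # 18 hours in: treated as a cache miss,
assert midpoint < now                  # so the key is re-fetched from the origin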

synapse/server.py

@@ -79,6 +79,7 @@ class BaseHomeServer(object):
        'resource_for_web_client',
        'resource_for_content_repo',
        'resource_for_server_key',
        'resource_for_server_key_v2',
        'resource_for_media_repository',
        'resource_for_metrics',
        'event_sources',

synapse/storage/__init__.py

@@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 17

dir_path = os.path.abspath(os.path.dirname(__file__))

synapse/storage/keys.py

@@ -122,3 +122,68 @@ class KeyStore(SQLBaseStore):
            },
            desc="store_server_verify_key",
        )
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
"""Stores the JSON bytes for a set of keys from a server
The JSON should be signed by the originating server, the intermediate
server, and by this server. Updates the value for the
(server_name, key_id, from_server) triplet if one already existed.
        Args:
            server_name (str): The name of the server.
            key_id (str): The identifier of the key this JSON is for.
            from_server (str): The server this JSON was fetched from.
            ts_now_ms (int): The time now in milliseconds.
            ts_expires_ms (int): The time when this json stops being valid.
            key_json_bytes (bytes): The encoded JSON.
        """
return self._simple_insert(
table="server_keys_json",
values={
"server_name": server_name,
"key_id": key_id,
"from_server": from_server,
"ts_added_ms": ts_now_ms,
"ts_valid_until_ms": ts_expires_ms,
"key_json": buffer(key_json_bytes),
},
or_replace=True,
)
def get_server_keys_json(self, server_keys):
"""Retrive the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list.
The JSON is returned as a byte array so that it can be efficiently
used in an HTTP response.
Args:
server_keys (list): List of (server_name, key_id, source) triplets.
Returns:
Dict mapping (server_name, key_id, source) triplets to dicts with
"ts_valid_until_ms" and "key_json" keys.
"""
def _get_server_keys_json_txn(txn):
results = {}
for server_name, key_id, from_server in server_keys:
keyvalues = {"server_name": server_name}
if key_id is not None:
keyvalues["key_id"] = key_id
if from_server is not None:
keyvalues["from_server"] = from_server
rows = self._simple_select_list_txn(
txn,
"server_keys_json",
keyvalues=keyvalues,
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
)
results[(server_name, key_id, from_server)] = rows
return results
return self.runInteraction(
"get_server_keys_json", _get_server_keys_json_txn
)
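Taken together, the two methods give a write/read cycle keyed on the (server_name, key_id, from_server) triplet. A hedged sketch of how the keyring and RemoteKey use them; the values, and the surrounding inlineCallbacks context with `store`, `clock` and `signed_key_json_bytes` in scope, are illustrative:

# Inside a @defer.inlineCallbacks method.
now = clock.time_msec()
yield store.store_server_keys_json(
    server_name="remote.server.example.com",
    key_id="ed25519:abc123",
    from_server="matrix.org",                  # where the signed JSON came from
    ts_now_ms=now,
    ts_expires_ms=now + 24 * 3600 * 1000,
    key_json_bytes=signed_key_json_bytes,      # canonical, signed JSON bytes
)

rows = yield store.get_server_keys_json([
    ("remote.server.example.com", "ed25519:abc123", None),   # None: any source
])
for (server_name, key_id, from_server), results in rows.items():
    for result in results:
        print result["ts_valid_until_ms"], bytes(result["key_json"])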

New SQL schema delta (new file): creates the server_keys_json table

@@ -0,0 +1,24 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS server_keys_json (
server_name TEXT, -- Server name.
key_id TEXT, -- Requested key id.
from_server TEXT, -- Which server the keys were fetched from.
ts_added_ms INTEGER, -- When the keys were fetched
    ts_valid_until_ms INTEGER, -- When this version of the keys expires.
key_json bytea, -- JSON certificate for the remote server.
CONSTRAINT uniqueness UNIQUE (server_name, key_id, from_server)
);