0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-03 11:18:56 +02:00

Merge pull request #8 from matrix-org/server2server_signing

Server2server signing
This commit is contained in:
Mark Haines 2014-10-14 10:06:04 +01:00
commit 636a0dbde7
28 changed files with 786 additions and 468 deletions

View file

@ -0,0 +1,151 @@
Signing JSON
============
JSON is signed by encoding the JSON object without ``signatures`` or ``meta``
keys using a canonical encoding. The JSON bytes are then signed using the
signature algorithm and the signature encoded using base64 with the padding
stripped. The resulting base64 signature is added to an object under the
*signing key identifier* which is added to the ``signatures`` object under the
name of the server signing it which is added back to the original JSON object
along with the ``meta`` object.
The *signing key identifier* is the concatenation of the *signing algorithm*
and a *key version*. The *signing algorithm* identifies the algorithm used to
sign the JSON. The currently support value for *signing algorithm* is
``ed25519`` as implemented by NACL (http://nacl.cr.yp.to/). The *key version*
is used to distinguish between different signing keys used by the same entity.
The ``meta`` object and the ``signatures`` object are not covered by the
signature. Therefore intermediate servers can add metadata such as time stamps
and additional signatures.
::
{
"name": "example.org",
"signing_keys": {
"ed25519:1": "XSl0kuyvrXNj6A+7/tkrB9sxSbRi08Of5uRhxOqZtEQ"
},
"meta": {
"retrieved_ts_ms": 922834800000
},
"signatures": {
"example.org": {
"ed25519:1": "s76RUgajp8w172am0zQb/iPTHsRnb4SkrzGoeCOSFfcBY2V/1c8QfrmdXHpvnc2jK5BD1WiJIxiMW95fMjK7Bw"
}
}
}
::
def sign_json(json_object, signing_key, signing_name):
signatures = json_object.pop("signatures", {})
meta = json_object.pop("meta", None)
signed = signing_key.sign(encode_canonical_json(json_object))
signature_base64 = encode_base64(signed.signature)
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
signatures.setdefault(sigature_name, {})[key_id] = signature_base64
json_object["signatures"] = signatures
if meta is not None:
json_object["meta"] = meta
return json_object
Checking for a Signature
------------------------
To check if an entity has signed a JSON object a server does the following
1. Checks if the ``signatures`` object contains an entry with the name of the
entity. If the entry is missing then the check fails.
2. Removes any *signing key identifiers* from the entry with algorithms it
doesn't understand. If there are no *signing key identifiers* left then the
check fails.
3. Looks up *verification keys* for the remaining *signing key identifiers*
either from a local cache or by consulting a trusted key server. If it
cannot find a *verification key* then the check fails.
4. Decodes the base64 encoded signature bytes. If base64 decoding fails then
the check fails.
5. Checks the signature bytes using the *verification key*. If this fails then
the check fails. Otherwise the check succeeds.
Canonical JSON
--------------
The canonical JSON encoding for a value is the shortest UTF-8 JSON encoding
with dictionary keys lexicographically sorted by unicode codepoint. Numbers in
the JSON value must be integers in the range [-(2**53)+1, (2**53)-1].
::
import json
def canonical_json(value):
return json.dumps(
value,
ensure_ascii=False,
separators=(',',':'),
sort_keys=True,
).encode("UTF-8")
Grammar
+++++++
Adapted from the grammar in http://tools.ietf.org/html/rfc7159 removing
insignificant whitespace, fractions, exponents and redundant character escapes
::
value = false / null / true / object / array / number / string
false = %x66.61.6c.73.65
null = %x6e.75.6c.6c
true = %x74.72.75.65
object = %x7B [ member *( %x2C member ) ] %7D
member = string %x3A value
array = %x5B [ value *( %x2C value ) ] %5B
number = [ %x2D ] int
int = %x30 / ( %x31-39 *digit )
digit = %x30-39
string = %x22 *char %x22
char = unescaped / %x5C escaped
unescaped = %x20-21 / %x23-5B / %x5D-10FFFF
escaped = %x22 ; " quotation mark U+0022
/ %x5C ; \ reverse solidus U+005C
/ %x62 ; b backspace U+0008
/ %x66 ; f form feed U+000C
/ %x6E ; n line feed U+000A
/ %x72 ; r carriage return U+000D
/ %x74 ; t tab U+0009
/ %x75.30.30.30 (%x30-37 / %x62 / %x65-66) ; u000X
/ %x75.30.30.31 (%x30-39 / %x61-66) ; u001X
Signing Events
==============
Signing events is a more complicated process since servers can choose to redact
non-essential event contents. Before signing the event it is encoded as
Canonical JSON and hashed using SHA-256. The resulting hash is then stored
in the event JSON in a ``hash`` object under a ``sha256`` key. Then all
non-essential keys are stripped from the event object, and the resulting object
which included the ``hash`` key is signed using the JSON signing algorithm.
Servers can then transmit the entire event or the event with the non-essential
keys removed. Receiving servers can then check the entire event if it is
present by computing the SHA-256 of the event excluding the ``hash`` object, or
by using the ``hash`` object included in the event if keys have been redacted.
New hash functions can be introduced by adding additional keys to the ``hash``
object. Since the ``hash`` object cannot be redacted a server shouldn't allow
too many hashes to be listed, otherwise a server might embed illict data within
the ``hash`` object. For similar reasons a server shouldn't allow hash values
that are too long.
[[TODO(markjh): We might want to specify a maximum number of keys for the
``hash`` and we might want to specify the maximum output size of a hash]]
[[TODO(markjh) We might want to allow the server to omit the output of well
known hash functions like SHA-256 when none of the keys have been redacted]]

View file

@ -19,6 +19,7 @@ import logging
class Codes(object):
UNAUTHORIZED = "M_UNAUTHORIZED"
FORBIDDEN = "M_FORBIDDEN"
BAD_JSON = "M_BAD_JSON"
NOT_JSON = "M_NOT_JSON"

View file

@ -18,4 +18,5 @@
CLIENT_PREFIX = "/_matrix/client/api/v1"
FEDERATION_PREFIX = "/_matrix/federation/v1"
WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content"
CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_PREFIX = "/_matrix/key/v1"

View file

@ -25,9 +25,11 @@ from twisted.web.static import File
from twisted.web.server import Site
from synapse.http.server import JsonResource, RootRedirect
from synapse.http.content_repository import ContentRepoResource
from synapse.http.server_key_resource import LocalKey
from synapse.http.client import MatrixHttpClient
from synapse.api.urls import (
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX,
)
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
@ -63,6 +65,9 @@ class SynapseHomeServer(HomeServer):
self, self.upload_dir, self.auth, self.content_addr
)
def build_resource_for_server_key(self):
return LocalKey(self)
def build_db_pool(self):
return adbapi.ConnectionPool(
"sqlite3", self.get_db_name(),
@ -88,7 +93,8 @@ class SynapseHomeServer(HomeServer):
desired_tree = [
(CLIENT_PREFIX, self.get_resource_for_client()),
(FEDERATION_PREFIX, self.get_resource_for_federation()),
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo())
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
]
if web_client:
logger.info("Adding the web client.")

View file

@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import nacl.signing
import os
from ._base import Config
from syutil.base64util import encode_base64, decode_base64
from ._base import Config, ConfigError
import syutil.crypto.signing_key
class ServerConfig(Config):
@ -70,9 +69,16 @@ class ServerConfig(Config):
"content repository")
def read_signing_key(self, signing_key_path):
signing_key_base64 = self.read_file(signing_key_path, "signing_key")
signing_key_bytes = decode_base64(signing_key_base64)
return nacl.signing.SigningKey(signing_key_bytes)
signing_keys = self.read_file(signing_key_path, "signing_key")
try:
return syutil.crypto.signing_key.read_signing_keys(
signing_keys.splitlines(True)
)
except Exception as e:
raise ConfigError(
"Error reading signing_key."
" Try running again with --generate-config"
)
@classmethod
def generate_config(cls, args, config_dir_path):
@ -86,6 +92,21 @@ class ServerConfig(Config):
if not os.path.exists(args.signing_key_path):
with open(args.signing_key_path, "w") as signing_key_file:
key = nacl.signing.SigningKey.generate()
signing_key_file.write(encode_base64(key.encode()))
syutil.crypto.signing_key.write_signing_keys(
signing_key_file,
(syutil.crypto.SigningKey.generate("auto"),),
)
else:
signing_keys = cls.read_file(args.signing_key_path, "signing_key")
if len(signing_keys.split("\n")[0].split()) == 1:
# handle keys in the old format.
key = syutil.crypto.signing_key.decode_signing_key_base64(
syutil.crypto.signing_key.NACL_ED25519,
"auto",
signing_keys.split("\n")[0]
)
with open(args.signing_key_path, "w") as signing_key_file:
syutil.crypto.signing_key.write_signing_keys(
signing_key_file,
(key,),
)

View file

@ -15,9 +15,10 @@
from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor
from twisted.internet.protocol import ClientFactory
from twisted.names.srvconnect import SRVConnector
from twisted.internet.endpoints import connectProtocol
from synapse.http.endpoint import matrix_endpoint
import json
import logging
@ -30,15 +31,19 @@ def fetch_server_key(server_name, ssl_context_factory):
"""Fetch the keys for a remote server."""
factory = SynapseKeyClientFactory()
endpoint = matrix_endpoint(
reactor, server_name, ssl_context_factory, timeout=30
)
SRVConnector(
reactor, "matrix", server_name, factory,
protocol="tcp", connectFuncName="connectSSL", defaultPort=443,
connectFuncKwArgs=dict(contextFactory=ssl_context_factory)).connect()
server_key, server_certificate = yield factory.remote_key
defer.returnValue((server_key, server_certificate))
for i in range(5):
try:
protocol = yield endpoint.connect(factory)
server_response, server_certificate = yield protocol.remote_key
defer.returnValue((server_response, server_certificate))
return
except Exception as e:
logger.exception(e)
raise IOError("Cannot get key for %s" % server_name)
class SynapseKeyClientError(Exception):
@ -51,69 +56,47 @@ class SynapseKeyClientProtocol(HTTPClient):
the server and extracts the X.509 certificate for the remote peer from the
SSL connection."""
timeout = 30
def __init__(self):
self.remote_key = defer.Deferred()
def connectionMade(self):
logger.debug("Connected to %s", self.transport.getHost())
self.sendCommand(b"GET", b"/key")
self.sendCommand(b"GET", b"/_matrix/key/v1/")
self.endHeaders()
self.timer = reactor.callLater(
self.factory.timeout_seconds,
self.timeout,
self.on_timeout
)
def handleStatus(self, version, status, message):
if status != b"200":
logger.info("Non-200 response from %s: %s %s",
self.transport.getHost(), status, message)
#logger.info("Non-200 response from %s: %s %s",
# self.transport.getHost(), status, message)
self.transport.abortConnection()
def handleResponse(self, response_body_bytes):
try:
json_response = json.loads(response_body_bytes)
except ValueError:
logger.info("Invalid JSON response from %s",
self.transport.getHost())
#logger.info("Invalid JSON response from %s",
# self.transport.getHost())
self.transport.abortConnection()
return
certificate = self.transport.getPeerCertificate()
self.factory.on_remote_key((json_response, certificate))
self.remote_key.callback((json_response, certificate))
self.transport.abortConnection()
self.timer.cancel()
def on_timeout(self):
logger.debug("Timeout waiting for response from %s",
self.transport.getHost())
self.remote_key.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection()
class SynapseKeyClientFactory(ClientFactory):
class SynapseKeyClientFactory(Factory):
protocol = SynapseKeyClientProtocol
max_retries = 5
timeout_seconds = 30
def __init__(self):
self.succeeded = False
self.retries = 0
self.remote_key = defer.Deferred()
def on_remote_key(self, key):
self.succeeded = True
self.remote_key.callback(key)
def retry_connection(self, connector):
self.retries += 1
if self.retries < self.max_retries:
connector.connector = None
connector.connect()
else:
self.remote_key.errback(
SynapseKeyClientError("Max retries exceeded"))
def clientConnectionFailed(self, connector, reason):
logger.info("Connection failed %s", reason)
self.retry_connection(connector)
def clientConnectionLost(self, connector, reason):
logger.info("Connection lost %s", reason)
if not self.succeeded:
self.retry_connection(connector)

154
synapse/crypto/keyring.py Normal file
View file

@ -0,0 +1,154 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.crypto.keyclient import fetch_server_key
from twisted.internet import defer
from syutil.crypto.jsonsign import verify_signed_json, signature_ids
from syutil.crypto.signing_key import (
is_signing_algorithm_supported, decode_verify_key_bytes
)
from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes
from OpenSSL import crypto
import logging
logger = logging.getLogger(__name__)
class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.hs = hs
@defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object):
key_ids = signature_ids(json_object, server_name)
if not key_ids:
raise SynapseError(
400,
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
)
try:
verify_key = yield self.get_server_verify_key(server_name, key_ids)
except IOError:
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except:
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED,
)
try:
verify_signed_json(json_object, server_name, verify_key)
except:
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
)
@defer.inlineCallbacks
def get_server_verify_key(self, server_name, key_ids):
"""Finds a verification key for the server with one of the key ids.
Args:
server_name (str): The name of the server to fetch a key for.
keys_ids (list of str): The key_ids to check for.
"""
# Check the datastore to see if we have one cached.
cached = yield self.store.get_server_verify_keys(server_name, key_ids)
if cached:
defer.returnValue(cached[0])
return
# Try to fetch the key from the remote server.
# TODO(markjh): Ratelimit requests to a given server.
(response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_context_factory
)
# Check the response.
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate
)
if ("signatures" not in response
or server_name not in response["signatures"]):
raise ValueError("Key response not signed by remote server")
if "tls_certificate" not in response:
raise ValueError("Key response missing TLS certificate")
tls_certificate_b64 = response["tls_certificate"]
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
raise ValueError("TLS certificate doesn't match")
verify_keys = {}
for key_id, key_base64 in response["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_keys[key_id] = verify_key
for key_id in response["signatures"][server_name]:
if key_id not in response["verify_keys"]:
raise ValueError(
"Key response must include verification keys for all"
" signatures"
)
if key_id in verify_keys:
verify_signed_json(
response,
server_name,
verify_keys[key_id]
)
# Cache the result in the datastore.
time_now_ms = self.clock.time_msec()
self.store.store_server_certificate(
server_name,
server_name,
time_now_ms,
tls_certificate,
)
for key_id, key in verify_keys.items():
self.store.store_server_verify_key(
server_name, server_name, time_now_ms, key
)
for key_id in key_ids:
if key_id in verify_keys:
defer.returnValue(verify_keys[key_id])
return
raise ValueError("No verification key found for given key ids")

View file

@ -1,111 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import reactor, ssl
from twisted.web import server
from twisted.web.resource import Resource
from twisted.python.log import PythonLoggingObserver
from synapse.crypto.resource.key import LocalKey
from synapse.crypto.config import load_config
from syutil.base64util import decode_base64
from OpenSSL import crypto, SSL
import logging
import nacl.signing
import sys
class KeyServerSSLContextFactory(ssl.ContextFactory):
"""Factory for PyOpenSSL SSL contexts that are used to handle incoming
connections and to make connections to remote servers."""
def __init__(self, key_server):
self._context = SSL.Context(SSL.SSLv23_METHOD)
self.configure_context(self._context, key_server)
@staticmethod
def configure_context(context, key_server):
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
context.use_certificate(key_server.tls_certificate)
context.use_privatekey(key_server.tls_private_key)
context.load_tmp_dh(key_server.tls_dh_params_path)
context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH")
def getContext(self):
return self._context
class KeyServer(object):
"""An HTTPS server serving LocalKey and RemoteKey resources."""
def __init__(self, server_name, tls_certificate_path, tls_private_key_path,
tls_dh_params_path, signing_key_path, bind_host, bind_port):
self.server_name = server_name
self.tls_certificate = self.read_tls_certificate(tls_certificate_path)
self.tls_private_key = self.read_tls_private_key(tls_private_key_path)
self.tls_dh_params_path = tls_dh_params_path
self.signing_key = self.read_signing_key(signing_key_path)
self.bind_host = bind_host
self.bind_port = int(bind_port)
self.ssl_context_factory = KeyServerSSLContextFactory(self)
@staticmethod
def read_tls_certificate(cert_path):
with open(cert_path) as cert_file:
cert_pem = cert_file.read()
return crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
@staticmethod
def read_tls_private_key(private_key_path):
with open(private_key_path) as private_key_file:
private_key_pem = private_key_file.read()
return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem)
@staticmethod
def read_signing_key(signing_key_path):
with open(signing_key_path) as signing_key_file:
signing_key_b64 = signing_key_file.read()
signing_key_bytes = decode_base64(signing_key_b64)
return nacl.signing.SigningKey(signing_key_bytes)
def run(self):
root = Resource()
root.putChild("key", LocalKey(self))
site = server.Site(root)
reactor.listenSSL(
self.bind_port,
site,
self.ssl_context_factory,
interface=self.bind_host
)
logging.basicConfig(level=logging.DEBUG)
observer = PythonLoggingObserver()
observer.start()
reactor.run()
def main():
key_server = KeyServer(**load_config(__doc__, sys.argv[1:]))
key_server.run()
if __name__ == "__main__":
main()

View file

@ -1,15 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -1,161 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
from synapse.http.server import respond_with_json_bytes
from synapse.crypto.keyclient import fetch_server_key
from syutil.crypto.jsonsign import sign_json, verify_signed_json
from syutil.base64util import encode_base64, decode_base64
from syutil.jsonutil import encode_canonical_json
from OpenSSL import crypto
from nacl.signing import VerifyKey
import logging
logger = logging.getLogger(__name__)
class LocalKey(Resource):
"""HTTP resource containing encoding the TLS X.509 certificate and NACL
signature verification keys for this server::
GET /key HTTP/1.1
HTTP/1.1 200 OK
Content-Type: application/json
{
"server_name": "this.server.example.com"
"signature_verify_key": # base64 encoded NACL verification key.
"tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert.
"signatures": {
"this.server.example.com": # NACL signature for this server.
}
}
"""
def __init__(self, key_server):
self.key_server = key_server
self.response_body = encode_canonical_json(
self.response_json_object(key_server)
)
Resource.__init__(self)
@staticmethod
def response_json_object(key_server):
verify_key_bytes = key_server.signing_key.verify_key.encode()
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1,
key_server.tls_certificate
)
json_object = {
u"server_name": key_server.server_name,
u"signature_verify_key": encode_base64(verify_key_bytes),
u"tls_certificate": encode_base64(x509_certificate_bytes)
}
signed_json = sign_json(
json_object,
key_server.server_name,
key_server.signing_key
)
return signed_json
def getChild(self, name, request):
logger.info("getChild %s %s", name, request)
if name == '':
return self
else:
return RemoteKey(name, self.key_server)
def render_GET(self, request):
return respond_with_json_bytes(request, 200, self.response_body)
class RemoteKey(Resource):
"""HTTP resource for retreiving the TLS certificate and NACL signature
verification keys for a another server. Checks that the reported X.509 TLS
certificate matches the one used in the HTTPS connection. Checks that the
NACL signature for the remote server is valid. Returns JSON signed by both
the remote server and by this server.
GET /key/remote.server.example.com HTTP/1.1
HTTP/1.1 200 OK
Content-Type: application/json
{
"server_name": "remote.server.example.com"
"signature_verify_key": # base64 encoded NACL verification key.
"tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert.
"signatures": {
"remote.server.example.com": # NACL signature for remote server.
"this.server.example.com": # NACL signature for this server.
}
}
"""
isLeaf = True
def __init__(self, server_name, key_server):
self.server_name = server_name
self.key_server = key_server
Resource.__init__(self)
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@defer.inlineCallbacks
def _async_render_GET(self, request):
try:
server_keys, certificate = yield fetch_server_key(
self.server_name,
self.key_server.ssl_context_factory
)
resp_server_name = server_keys[u"server_name"]
verify_key_b64 = server_keys[u"signature_verify_key"]
tls_certificate_b64 = server_keys[u"tls_certificate"]
verify_key = VerifyKey(decode_base64(verify_key_b64))
if resp_server_name != self.server_name:
raise ValueError("Wrong server name '%s' != '%s'" %
(resp_server_name, self.server_name))
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1,
certificate
)
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
raise ValueError("TLS certificate doesn't match")
verify_signed_json(server_keys, self.server_name, verify_key)
signed_json = sign_json(
server_keys,
self.key_server.server_name,
self.key_server.signing_key
)
json_bytes = encode_canonical_json(signed_json)
respond_with_json_bytes(request, 200, json_bytes)
except Exception as e:
json_bytes = encode_canonical_json({
u"error": {u"code": 502, u"message": e.message}
})
respond_with_json_bytes(request, 502, json_bytes)

View file

@ -22,6 +22,7 @@ from .transport import TransportLayer
def initialize_http_replication(homeserver):
transport = TransportLayer(
homeserver,
homeserver.hostname,
server=homeserver.get_resource_for_federation(),
client=homeserver.get_http_client()

View file

@ -492,7 +492,6 @@ class _TransactionQueue(object):
"""
def __init__(self, hs, transaction_actions, transport_layer):
self.server_name = hs.hostname
self.transaction_actions = transaction_actions
self.transport_layer = transport_layer
@ -591,7 +590,7 @@ class _TransactionQueue(object):
transaction = Transaction.create_new(
ts=self._clock.time_msec(),
transaction_id=self._next_txn_id,
transaction_id=str(self._next_txn_id),
origin=self.server_name,
destination=destination,
pdus=pdus,
@ -609,18 +608,17 @@ class _TransactionQueue(object):
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def cb(transaction):
def json_data_cb():
data = transaction.get_dict()
now = int(self._clock.time_msec())
if "pdus" in transaction:
for p in transaction["pdus"]:
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
p["age"] = now - int(p["age_ts"])
return transaction
return data
code, response = yield self.transport_layer.send_transaction(
transaction,
on_send_callback=cb,
transaction, json_data_cb
)
logger.debug("TX [%s] Sent transaction", destination)

View file

@ -24,6 +24,7 @@ over a different (albeit still reliable) protocol.
from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
from synapse.util.logutils import log_function
import logging
@ -54,7 +55,7 @@ class TransportLayer(object):
we receive data.
"""
def __init__(self, server_name, server, client):
def __init__(self, homeserver, server_name, server, client):
"""
Args:
server_name (str): Local home server host
@ -63,6 +64,7 @@ class TransportLayer(object):
client (synapse.protocol.http.HttpClient): the http client used to
send requests
"""
self.keyring = homeserver.get_keyring()
self.server_name = server_name
self.server = server
self.client = client
@ -144,7 +146,7 @@ class TransportLayer(object):
@defer.inlineCallbacks
@log_function
def send_transaction(self, transaction, on_send_callback=None):
def send_transaction(self, transaction, json_data_callback=None):
""" Sends the given Transaction to it's destination
Args:
@ -163,25 +165,15 @@ class TransportLayer(object):
if transaction.destination == self.server_name:
raise RuntimeError("Transport layer cannot send to itself!")
data = transaction.get_dict()
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def cb(destination, method, path_bytes, producer):
if not on_send_callback:
return
transaction = json.loads(producer.body)
new_transaction = on_send_callback(transaction)
producer.reset(new_transaction)
# FIXME: This is only used by the tests. The actual json sent is
# generated by the json_data_callback.
json_data = transaction.get_dict()
code, response = yield self.client.put_json(
transaction.destination,
path=PREFIX + "/send/%s/" % transaction.transaction_id,
data=data,
on_send_callback=cb,
data=json_data,
json_data_callback=json_data_callback,
)
logger.debug(
@ -205,6 +197,72 @@ class TransportLayer(object):
defer.returnValue(response)
@defer.inlineCallbacks
def _authenticate_request(self, request):
json_request = {
"method": request.method,
"uri": request.uri,
"destination": self.server_name,
"signatures": {},
}
content = None
origin = None
if request.method == "PUT":
#TODO: Handle other method types? other content types?
try:
content_bytes = request.content.read()
content = json.loads(content_bytes)
json_request["content"] = content
except:
raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
def parse_auth_header(header_str):
try:
params = auth.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value):
if value.startswith("\""):
return value[1:-1]
else:
return value
origin = strip_quotes(param_dict["origin"])
key = strip_quotes(param_dict["key"])
sig = strip_quotes(param_dict["sig"])
return (origin, key, sig)
except:
raise SynapseError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED
)
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
for auth in auth_headers:
if auth.startswith("X-Matrix"):
(origin, key, sig) = parse_auth_header(auth)
json_request["origin"] = origin
json_request["signatures"].setdefault(origin,{})[key] = sig
if not json_request["signatures"]:
raise SynapseError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
yield self.keyring.verify_json_for_server(origin, json_request)
defer.returnValue((origin, content))
def _with_authentication(self, handler):
@defer.inlineCallbacks
def new_handler(request, *args, **kwargs):
(origin, content) = yield self._authenticate_request(request)
response = yield handler(
origin, content, request.args, *args, **kwargs
)
defer.returnValue(response)
return new_handler
@log_function
def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data.
@ -218,7 +276,7 @@ class TransportLayer(object):
self.server.register_path(
"PUT",
re.compile("^" + PREFIX + "/send/([^/]*)/$"),
self._on_send_request
self._with_authentication(self._on_send_request)
)
@log_function
@ -236,9 +294,9 @@ class TransportLayer(object):
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/pull/$"),
lambda request: handler.on_pull_request(
request.args["origin"][0],
request.args["v"]
self._with_authentication(
lambda origin, content, query:
handler.on_pull_request(query["origin"][0], query["v"])
)
)
@ -247,8 +305,9 @@ class TransportLayer(object):
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
lambda request, pdu_origin, pdu_id: handler.on_pdu_request(
pdu_origin, pdu_id
self._with_authentication(
lambda origin, content, query, pdu_origin, pdu_id:
handler.on_pdu_request(pdu_origin, pdu_id)
)
)
@ -256,38 +315,47 @@ class TransportLayer(object):
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/state/([^/]*)/$"),
lambda request, context: handler.on_context_state_request(
context
self._with_authentication(
lambda origin, content, query, context:
handler.on_context_state_request(context)
)
)
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
lambda request, context: self._on_backfill_request(
context, request.args["v"],
request.args["limit"]
self._with_authentication(
lambda origin, content, query, context:
self._on_backfill_request(
context, query["v"], query["limit"]
)
)
)
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/context/([^/]*)/$"),
lambda request, context: handler.on_context_pdus_request(context)
self._with_authentication(
lambda origin, content, query, context:
handler.on_context_pdus_request(context)
)
)
# This is when we receive a server-server Query
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/query/([^/]*)$"),
lambda request, query_type: handler.on_query_request(
query_type, {k: v[0] for k, v in request.args.items()}
self._with_authentication(
lambda origin, content, query, query_type:
handler.on_query_request(
query_type, {k: v[0] for k, v in query.items()}
)
)
)
@defer.inlineCallbacks
@log_function
def _on_send_request(self, request, transaction_id):
def _on_send_request(self, origin, content, query, transaction_id):
""" Called on PUT /send/<transaction_id>/
Args:
@ -302,12 +370,7 @@ class TransportLayer(object):
"""
# Parse the request
try:
data = request.content.read()
l = data[:20].encode("string_escape")
logger.debug("Got data: \"%s\"", l)
transaction_data = json.loads(data)
transaction_data = content
logger.debug(
"Decoded %s: %s",

View file

@ -186,6 +186,8 @@ class Transaction(JsonEncodedObject):
"previous_ids",
"pdus",
"edus",
"transaction_id",
"destination",
]
internal_keys = [

View file

@ -26,6 +26,8 @@ from syutil.jsonutil import encode_canonical_json
from synapse.api.errors import CodeMessageException, SynapseError
from syutil.crypto.jsonsign import sign_json
from StringIO import StringIO
import json
@ -147,7 +149,7 @@ class BaseHttpClient(object):
class MatrixHttpClient(BaseHttpClient):
""" Wrapper around the twisted HTTP client api. Implements
""" Wrapper around the twisted HTTP client api. Implements
Attributes:
agent (twisted.web.client.Agent): The twisted Agent used to send the
@ -156,8 +158,42 @@ class MatrixHttpClient(BaseHttpClient):
RETRY_DNS_LOOKUP_FAILURES = "__retry_dns"
def __init__(self, hs):
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
BaseHttpClient.__init__(self, hs)
def sign_request(self, destination, method, url_bytes, headers_dict,
content=None):
request = {
"method": method,
"uri": url_bytes,
"origin": self.server_name,
"destination": destination,
}
if content is not None:
request["content"] = content
request = sign_json(request, self.server_name, self.signing_key)
from syutil.jsonutil import encode_canonical_json
logger.debug("Signing " + " " * 11 + "%s %s",
self.server_name, encode_canonical_json(request))
auth_headers = []
for key,sig in request["signatures"][self.server_name].items():
auth_headers.append(bytes(
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
self.server_name, key, sig,
)
))
headers_dict[b"Authorization"] = auth_headers
@defer.inlineCallbacks
def put_json(self, destination, path, data, on_send_callback=None):
def put_json(self, destination, path, data={}, json_data_callback=None):
""" Sends the specifed json data using PUT
Args:
@ -166,6 +202,8 @@ class MatrixHttpClient(BaseHttpClient):
path (str): The HTTP path.
data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to
use as the request body.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
@ -173,13 +211,16 @@ class MatrixHttpClient(BaseHttpClient):
CodeMessageException is raised.
"""
if not on_send_callback:
def on_send_callback(destination, method, path_bytes, producer):
pass
if not json_data_callback:
def json_data_callback():
return data
def body_callback(method, url_bytes, headers_dict):
producer = _JsonProducer(data)
on_send_callback(destination, method, path, producer)
json_data = json_data_callback()
self.sign_request(
destination, method, url_bytes, headers_dict, json_data
)
producer = _JsonProducer(json_data)
return producer
response = yield self._create_request(
@ -221,6 +262,7 @@ class MatrixHttpClient(BaseHttpClient):
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
response = yield self._create_request(

View file

@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.web.resource import Resource
from synapse.http.server import respond_with_json_bytes
from syutil.crypto.jsonsign import sign_json
from syutil.base64util import encode_base64
from syutil.jsonutil import encode_canonical_json
from OpenSSL import crypto
import logging
logger = logging.getLogger(__name__)
class LocalKey(Resource):
"""HTTP resource containing encoding the TLS X.509 certificate and NACL
signature verification keys for this server::
GET /key HTTP/1.1
HTTP/1.1 200 OK
Content-Type: application/json
{
"server_name": "this.server.example.com"
"verify_keys": {
"algorithm:version": # base64 encoded NACL verification key.
},
"tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert.
"signatures": {
"this.server.example.com": {
"algorithm:version": # NACL signature for this server.
}
}
}
"""
def __init__(self, hs):
self.hs = hs
self.response_body = encode_canonical_json(
self.response_json_object(hs.config)
)
Resource.__init__(self)
@staticmethod
def response_json_object(server_config):
verify_keys = {}
for key in server_config.signing_key:
verify_key_bytes = key.verify_key.encode()
key_id = "%s:%s" % (key.alg, key.version)
verify_keys[key_id] = encode_base64(verify_key_bytes)
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1,
server_config.tls_certificate
)
json_object = {
u"server_name": server_config.server_name,
u"verify_keys": verify_keys,
u"tls_certificate": encode_base64(x509_certificate_bytes)
}
for key in server_config.signing_key:
json_object = sign_json(
json_object,
server_config.server_name,
key,
)
return json_object
def render_GET(self, request):
return respond_with_json_bytes(request, 200, self.response_body)
def getChild(self, name, request):
if name == '':
return self

View file

@ -34,6 +34,7 @@ from synapse.util.distributor import Distributor
from synapse.util.lockutils import LockManager
from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring
class BaseHomeServer(object):
@ -75,8 +76,10 @@ class BaseHomeServer(object):
'resource_for_federation',
'resource_for_web_client',
'resource_for_content_repo',
'resource_for_server_key',
'event_sources',
'ratelimiter',
'keyring',
]
def __init__(self, hostname, **kwargs):
@ -212,6 +215,9 @@ class HomeServer(BaseHomeServer):
def build_ratelimiter(self):
return Ratelimiter()
def build_keyring(self):
return Keyring(self)
def register_servlets(self):
""" Register all servlets associated with this HomeServer.
"""

View file

@ -57,6 +57,7 @@ SCHEMAS = [
"presence",
"im",
"room_aliases",
"keys",
"redactions",
]

View file

@ -121,7 +121,7 @@ class SQLBaseStore(object):
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
def _simple_insert(self, table, values, or_replace=False):
def _simple_insert(self, table, values, or_replace=False, or_ignore=False):
"""Executes an INSERT query on the named table.
Args:
@ -130,13 +130,16 @@ class SQLBaseStore(object):
or_replace : bool; if True performs an INSERT OR REPLACE
"""
return self.runInteraction(
self._simple_insert_txn, table, values, or_replace=or_replace
self._simple_insert_txn, table, values, or_replace=or_replace,
or_ignore=or_ignore,
)
@log_function
def _simple_insert_txn(self, txn, table, values, or_replace=False):
def _simple_insert_txn(self, txn, table, values, or_replace=False,
or_ignore=False):
sql = "%s INTO %s (%s) VALUES(%s)" % (
("INSERT OR REPLACE" if or_replace else "INSERT"),
("INSERT OR REPLACE" if or_replace else
"INSERT OR IGNORE" if or_ignore else "INSERT"),
table,
", ".join(k for k in values),
", ".join("?" for k in values)

View file

@ -18,7 +18,8 @@ from _base import SQLBaseStore
from twisted.internet import defer
import OpenSSL
import nacl.signing
from syutil.crypto.signing_key import decode_verify_key_bytes
import hashlib
class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys and tls X.509 certificates
@ -42,62 +43,76 @@ class KeyStore(SQLBaseStore):
)
defer.returnValue(tls_certificate)
def store_server_certificate(self, server_name, key_server, ts_now_ms,
def store_server_certificate(self, server_name, from_server, time_now_ms,
tls_certificate):
"""Stores the TLS X.509 certificate for the given server
Args:
server_name (bytes): The name of the server.
key_server (bytes): Where the certificate was looked up
ts_now_ms (int): The time now in milliseconds
server_name (str): The name of the server.
from_server (str): Where the certificate was looked up
time_now_ms (int): The time now in milliseconds
tls_certificate (OpenSSL.crypto.X509): The X.509 certificate.
"""
tls_certificate_bytes = OpenSSL.crypto.dump_certificate(
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate
)
fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest()
return self._simple_insert(
table="server_tls_certificates",
keyvalues={
values={
"server_name": server_name,
"key_server": key_server,
"ts_added_ms": ts_now_ms,
"tls_certificate": tls_certificate_bytes,
"fingerprint": fingerprint,
"from_server": from_server,
"ts_added_ms": time_now_ms,
"tls_certificate": buffer(tls_certificate_bytes),
},
or_ignore=True,
)
@defer.inlineCallbacks
def get_server_verification_key(self, server_name):
"""Retrieve the NACL verification key for a given server
def get_server_verify_keys(self, server_name, key_ids):
"""Retrieve the NACL verification key for a given server for the given
key_ids
Args:
server_name (bytes): The name of the server.
server_name (str): The name of the server.
key_ids (list of str): List of key_ids to try and look up.
Returns:
(nacl.signing.VerifyKey): The verification key.
(list of VerifyKey): The verification keys.
"""
verification_key_bytes, = yield self._simple_select_one(
table="server_signature_keys",
key_values={"server_name": server_name},
retcols=("tls_certificate",),
sql = (
"SELECT key_id, verify_key FROM server_signature_keys"
" WHERE server_name = ?"
" AND key_id in (" + ",".join("?" for key_id in key_ids) + ")"
)
verification_key = nacl.signing.VerifyKey(verification_key_bytes)
defer.returnValue(verification_key)
def store_server_verification_key(self, server_name, key_version,
key_server, ts_now_ms, verification_key):
rows = yield self._execute_and_decode(sql, server_name, *key_ids)
keys = []
for row in rows:
key_id = row["key_id"]
key_bytes = row["verify_key"]
key = decode_verify_key_bytes(key_id, str(key_bytes))
keys.append(key)
defer.returnValue(keys)
def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key):
"""Stores a NACL verification key for the given server.
Args:
server_name (bytes): The name of the server.
key_version (bytes): The version of the key for the server.
key_server (bytes): Where the verification key was looked up
server_name (str): The name of the server.
key_id (str): The version of the key for the server.
from_server (str): Where the verification key was looked up
ts_now_ms (int): The time now in milliseconds
verification_key (nacl.signing.VerifyKey): The NACL verify key.
verification_key (VerifyKey): The NACL verify key.
"""
verification_key_bytes = verification_key.encode()
verify_key_bytes = verify_key.encode()
return self._simple_insert(
table="server_signature_keys",
key_values={
values={
"server_name": server_name,
"key_version": key_version,
"key_server": key_server,
"ts_added_ms": ts_now_ms,
"verification_key": verification_key_bytes,
"key_id": "%s:%s" % (verify_key.alg, verify_key.version),
"from_server": from_server,
"ts_added_ms": time_now_ms,
"verify_key": buffer(verify_key.encode()),
},
or_ignore=True,
)

View file

@ -14,17 +14,18 @@
*/
CREATE TABLE IF NOT EXISTS server_tls_certificates(
server_name TEXT, -- Server name.
key_server TEXT, -- Which key server the certificate was fetched from.
fingerprint TEXT, -- Certificate fingerprint.
from_server TEXT, -- Which key server the certificate was fetched from.
ts_added_ms INTEGER, -- When the certifcate was added.
tls_certificate BLOB, -- DER encoded x509 certificate.
CONSTRAINT uniqueness UNIQUE (server_name)
CONSTRAINT uniqueness UNIQUE (server_name, fingerprint)
);
CREATE TABLE IF NOT EXISTS server_signature_keys(
server_name TEXT, -- Server name.
key_version TEXT, -- Key version.
key_server TEXT, -- Which key server the key was fetched form.
key_id TEXT, -- Key version.
from_server TEXT, -- Which key server the key was fetched form.
ts_added_ms INTEGER, -- When the key was added.
verification_key BLOB, -- NACL verification key.
CONSTRAINT uniqueness UNIQUE (server_name, key_version)
verify_key BLOB, -- NACL verification key.
CONSTRAINT uniqueness UNIQUE (server_name, key_id)
);

View file

@ -19,7 +19,7 @@ from tests import unittest
# python imports
from mock import Mock, ANY
from ..utils import MockHttpResource, MockClock
from ..utils import MockHttpResource, MockClock, MockKey
from synapse.server import HomeServer
from synapse.federation import initialize_http_replication
@ -64,6 +64,8 @@ class FederationTestCase(unittest.TestCase):
self.mock_persistence.get_received_txn_response.return_value = (
defer.succeed(None)
)
self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()]
self.clock = MockClock()
hs = HomeServer("test",
resource_for_federation=self.mock_resource,
@ -71,6 +73,8 @@ class FederationTestCase(unittest.TestCase):
db_pool=None,
datastore=self.mock_persistence,
clock=self.clock,
config=self.mock_config,
keyring=Mock(),
)
self.federation = initialize_http_replication(hs)
self.distributor = hs.get_distributor()
@ -182,7 +186,7 @@ class FederationTestCase(unittest.TestCase):
},
]
},
on_send_callback=ANY,
json_data_callback=ANY,
)
@defer.inlineCallbacks
@ -214,9 +218,10 @@ class FederationTestCase(unittest.TestCase):
}
],
},
on_send_callback=ANY,
json_data_callback=ANY,
)
@defer.inlineCallbacks
def test_recv_edu(self):
recv_observer = Mock()

View file

@ -26,12 +26,16 @@ from synapse.federation.units import Pdu
from mock import NonCallableMock, ANY
from ..utils import get_mock_call_args
from ..utils import get_mock_call_args, MockKey
class FederationTestCase(unittest.TestCase):
def setUp(self):
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
self.hostname = "test"
hs = HomeServer(
self.hostname,
@ -48,6 +52,7 @@ class FederationTestCase(unittest.TestCase):
"room_member_handler",
"federation_handler",
]),
config=self.mock_config,
)
self.datastore = hs.get_datastore()

View file

@ -17,11 +17,12 @@
from tests import unittest
from twisted.internet import defer, reactor
from mock import Mock, call, ANY
from mock import Mock, call, ANY, NonCallableMock, patch
import json
from tests.utils import (
MockHttpResource, MockClock, DeferredMockCallable, SQLiteMemoryDbPool
MockHttpResource, MockClock, DeferredMockCallable, SQLiteMemoryDbPool,
MockKey
)
from synapse.server import HomeServer
@ -58,7 +59,6 @@ class JustPresenceHandlers(object):
def __init__(self, hs):
self.presence_handler = PresenceHandler(hs)
class PresenceStateTestCase(unittest.TestCase):
""" Tests presence management. """
@ -67,12 +67,17 @@ class PresenceStateTestCase(unittest.TestCase):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer("test",
clock=MockClock(),
db_pool=db_pool,
handlers=None,
resource_for_federation=Mock(),
http_client=None,
config=self.mock_config,
keyring=Mock(),
)
hs.handlers = JustPresenceHandlers(hs)
@ -214,6 +219,9 @@ class PresenceInvitesTestCase(unittest.TestCase):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer("test",
clock=MockClock(),
db_pool=db_pool,
@ -221,6 +229,8 @@ class PresenceInvitesTestCase(unittest.TestCase):
resource_for_client=Mock(),
resource_for_federation=self.mock_federation_resource,
http_client=self.mock_http_client,
config=self.mock_config,
keyring=Mock(),
)
hs.handlers = JustPresenceHandlers(hs)
@ -290,7 +300,7 @@ class PresenceInvitesTestCase(unittest.TestCase):
"observed_user": "@cabbage:elsewhere",
}
),
on_send_callback=ANY,
json_data_callback=ANY,
),
defer.succeed((200, "OK"))
)
@ -319,7 +329,7 @@ class PresenceInvitesTestCase(unittest.TestCase):
"observed_user": "@apple:test",
}
),
on_send_callback=ANY,
json_data_callback=ANY,
),
defer.succeed((200, "OK"))
)
@ -355,7 +365,7 @@ class PresenceInvitesTestCase(unittest.TestCase):
"observed_user": "@durian:test",
}
),
on_send_callback=ANY,
json_data_callback=ANY,
),
defer.succeed((200, "OK"))
)
@ -503,6 +513,9 @@ class PresencePushTestCase(unittest.TestCase):
self.mock_federation_resource = MockHttpResource()
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer("test",
clock=self.clock,
db_pool=None,
@ -520,6 +533,8 @@ class PresencePushTestCase(unittest.TestCase):
resource_for_client=Mock(),
resource_for_federation=self.mock_federation_resource,
http_client=self.mock_http_client,
config=self.mock_config,
keyring=Mock(),
)
hs.handlers = JustPresenceHandlers(hs)
@ -771,7 +786,7 @@ class PresencePushTestCase(unittest.TestCase):
],
}
),
on_send_callback=ANY,
json_data_callback=ANY,
),
defer.succeed((200, "OK"))
)
@ -787,7 +802,7 @@ class PresencePushTestCase(unittest.TestCase):
],
}
),
on_send_callback=ANY,
json_data_callback=ANY,
),
defer.succeed((200, "OK"))
)
@ -913,7 +928,7 @@ class PresencePushTestCase(unittest.TestCase):
],
}
),
on_send_callback=ANY,
json_data_callback=ANY,
),
defer.succeed((200, "OK"))
)
@ -928,7 +943,7 @@ class PresencePushTestCase(unittest.TestCase):
],
}
),
on_send_callback=ANY,
json_data_callback=ANY,
),
defer.succeed((200, "OK"))
)
@ -958,7 +973,7 @@ class PresencePushTestCase(unittest.TestCase):
],
}
),
on_send_callback=ANY,
json_data_callback=ANY,
),
defer.succeed((200, "OK"))
)
@ -995,6 +1010,9 @@ class PresencePollingTestCase(unittest.TestCase):
self.mock_federation_resource = MockHttpResource()
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer("test",
clock=MockClock(),
db_pool=None,
@ -1009,6 +1027,8 @@ class PresencePollingTestCase(unittest.TestCase):
resource_for_client=Mock(),
resource_for_federation=self.mock_federation_resource,
http_client=self.mock_http_client,
config=self.mock_config,
keyring=Mock(),
)
hs.handlers = JustPresenceHandlers(hs)
@ -1155,7 +1175,7 @@ class PresencePollingTestCase(unittest.TestCase):
"poll": [ "@potato:remote" ],
},
),
on_send_callback=ANY,
json_data_callback=ANY,
),
defer.succeed((200, "OK"))
)
@ -1168,7 +1188,7 @@ class PresencePollingTestCase(unittest.TestCase):
"push": [ {"user_id": "@clementine:test" }],
},
),
on_send_callback=ANY,
json_data_callback=ANY,
),
defer.succeed((200, "OK"))
)
@ -1197,7 +1217,7 @@ class PresencePollingTestCase(unittest.TestCase):
"push": [ {"user_id": "@fig:test" }],
},
),
on_send_callback=ANY,
json_data_callback=ANY,
),
defer.succeed((200, "OK"))
)
@ -1230,7 +1250,7 @@ class PresencePollingTestCase(unittest.TestCase):
"unpoll": [ "@potato:remote" ],
},
),
on_send_callback=ANY,
json_data_callback=ANY,
),
defer.succeed((200, "OK"))
)
@ -1262,7 +1282,7 @@ class PresencePollingTestCase(unittest.TestCase):
],
},
),
on_send_callback=ANY,
json_data_callback=ANY,
),
defer.succeed((200, "OK"))
)

View file

@ -24,6 +24,7 @@ from synapse.api.constants import Membership
from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler
from synapse.handlers.profile import ProfileHandler
from synapse.server import HomeServer
from ..utils import MockKey
from mock import Mock, NonCallableMock
@ -31,6 +32,8 @@ from mock import Mock, NonCallableMock
class RoomMemberHandlerTestCase(unittest.TestCase):
def setUp(self):
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
self.hostname = "red"
hs = HomeServer(
self.hostname,
@ -38,7 +41,6 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
datastore=NonCallableMock(spec_set=[
"persist_event",
"get_joined_hosts_for_room",
@ -57,6 +59,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
]),
auth=NonCallableMock(spec_set=["check"]),
state_handler=NonCallableMock(spec_set=["handle_new_event"]),
config=self.mock_config,
)
self.federation = NonCallableMock(spec_set=[

View file

@ -20,7 +20,7 @@ from twisted.internet import defer
from mock import Mock, call, ANY
import json
from ..utils import MockHttpResource, MockClock, DeferredMockCallable
from ..utils import MockHttpResource, MockClock, DeferredMockCallable, MockKey
from synapse.server import HomeServer
from synapse.handlers.typing import TypingNotificationHandler
@ -61,6 +61,9 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.mock_federation_resource = MockHttpResource()
self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer("test",
clock=self.clock,
db_pool=None,
@ -75,6 +78,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
resource_for_client=Mock(),
resource_for_federation=self.mock_federation_resource,
http_client=self.mock_http_client,
config=self.mock_config,
keyring=Mock(),
)
hs.handlers = JustTypingNotificationHandlers(hs)
@ -170,7 +175,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
"typing": True,
}
),
on_send_callback=ANY,
json_data_callback=ANY,
),
defer.succeed((200, "OK"))
)
@ -221,7 +226,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
"typing": False,
}
),
on_send_callback=ANY,
json_data_callback=ANY,
),
defer.succeed((200, "OK"))
)

View file

@ -20,7 +20,7 @@ from twisted.internet import defer
from mock import Mock
from ..utils import MockHttpResource
from ..utils import MockHttpResource, MockKey
from synapse.api.constants import PresenceState
from synapse.handlers.presence import PresenceHandler
@ -45,7 +45,8 @@ class PresenceStateTestCase(unittest.TestCase):
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer("test",
db_pool=None,
datastore=Mock(spec=[
@ -56,7 +57,7 @@ class PresenceStateTestCase(unittest.TestCase):
http_client=None,
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
config=Mock(),
config=self.mock_config,
)
hs.handlers = JustPresenceHandlers(hs)
@ -125,6 +126,8 @@ class PresenceListTestCase(unittest.TestCase):
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer("test",
db_pool=None,
@ -142,7 +145,7 @@ class PresenceListTestCase(unittest.TestCase):
http_client=None,
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
config=Mock(),
config=self.mock_config,
)
hs.handlers = JustPresenceHandlers(hs)
@ -237,6 +240,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()]
# HIDEOUS HACKERY
# TODO(paul): This should be injected in via the HomeServer DI system
from synapse.streams.events import (
@ -267,6 +273,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
"cancel_call_later",
"time_msec",
]),
config=self.mock_config,
)
hs.get_clock().time_msec.return_value = 1000000

View file

@ -76,6 +76,13 @@ class MockHttpResource(HttpServer):
mock_content.configure_mock(**config)
mock_request.content = mock_content
mock_request.method = http_method
mock_request.uri = path
mock_request.requestHeaders.getRawHeaders.return_value=[
"X-Matrix origin=test,key=,sig="
]
# return the right path if the event requires it
mock_request.path = path
@ -108,6 +115,21 @@ class MockHttpResource(HttpServer):
self.callbacks.append((method, path_pattern, callback))
class MockKey(object):
alg = "mock_alg"
version = "mock_version"
@property
def verify_key(self):
return self
def sign(self, message):
return b"\x9a\x87$"
def verify(self, message, sig):
assert sig == b"\x9a\x87$"
class MockClock(object):
now = 1000