SYN-75 Verify signatures on server to server transactions

This commit is contained in:
Mark Haines 2014-09-30 15:15:10 +01:00
parent 52ca867670
commit b95a178584
14 changed files with 245 additions and 235 deletions

View file

@ -15,9 +15,10 @@
from twisted.web.http import HTTPClient from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.protocol import ClientFactory from twisted.internet.endpoints import connectProtocol
from twisted.names.srvconnect import SRVConnector from synapse.http.endpoint import matrix_endpoint
import json import json
import logging import logging
@ -30,15 +31,19 @@ def fetch_server_key(server_name, ssl_context_factory):
"""Fetch the keys for a remote server.""" """Fetch the keys for a remote server."""
factory = SynapseKeyClientFactory() factory = SynapseKeyClientFactory()
endpoint = matrix_endpoint(
reactor, server_name, ssl_context_factory, timeout=30
)
SRVConnector( for i in range(5):
reactor, "matrix", server_name, factory, try:
protocol="tcp", connectFuncName="connectSSL", defaultPort=443, protocol = yield endpoint.connect(factory)
connectFuncKwArgs=dict(contextFactory=ssl_context_factory)).connect() server_response, server_certificate = yield protocol.remote_key
defer.returnValue((server_response, server_certificate))
server_key, server_certificate = yield factory.remote_key return
except Exception as e:
defer.returnValue((server_key, server_certificate)) logger.exception(e)
raise IOError("Cannot get key for " % server_name)
class SynapseKeyClientError(Exception): class SynapseKeyClientError(Exception):
@ -51,69 +56,47 @@ class SynapseKeyClientProtocol(HTTPClient):
the server and extracts the X.509 certificate for the remote peer from the the server and extracts the X.509 certificate for the remote peer from the
SSL connection.""" SSL connection."""
timeout = 30
def __init__(self):
self.remote_key = defer.Deferred()
def connectionMade(self): def connectionMade(self):
logger.debug("Connected to %s", self.transport.getHost()) logger.debug("Connected to %s", self.transport.getHost())
self.sendCommand(b"GET", b"/key") self.sendCommand(b"GET", b"/_matrix/key/v1/")
self.endHeaders() self.endHeaders()
self.timer = reactor.callLater( self.timer = reactor.callLater(
self.factory.timeout_seconds, self.timeout,
self.on_timeout self.on_timeout
) )
def handleStatus(self, version, status, message): def handleStatus(self, version, status, message):
if status != b"200": if status != b"200":
logger.info("Non-200 response from %s: %s %s", #logger.info("Non-200 response from %s: %s %s",
self.transport.getHost(), status, message) # self.transport.getHost(), status, message)
self.transport.abortConnection() self.transport.abortConnection()
def handleResponse(self, response_body_bytes): def handleResponse(self, response_body_bytes):
try: try:
json_response = json.loads(response_body_bytes) json_response = json.loads(response_body_bytes)
except ValueError: except ValueError:
logger.info("Invalid JSON response from %s", #logger.info("Invalid JSON response from %s",
self.transport.getHost()) # self.transport.getHost())
self.transport.abortConnection() self.transport.abortConnection()
return return
certificate = self.transport.getPeerCertificate() certificate = self.transport.getPeerCertificate()
self.factory.on_remote_key((json_response, certificate)) self.remote_key.callback((json_response, certificate))
self.transport.abortConnection() self.transport.abortConnection()
self.timer.cancel() self.timer.cancel()
def on_timeout(self): def on_timeout(self):
logger.debug("Timeout waiting for response from %s", logger.debug("Timeout waiting for response from %s",
self.transport.getHost()) self.transport.getHost())
self.on_remote_key.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection() self.transport.abortConnection()
class SynapseKeyClientFactory(ClientFactory): class SynapseKeyClientFactory(Factory):
protocol = SynapseKeyClientProtocol protocol = SynapseKeyClientProtocol
max_retries = 5
timeout_seconds = 30
def __init__(self):
self.succeeded = False
self.retries = 0
self.remote_key = defer.Deferred()
def on_remote_key(self, key):
self.succeeded = True
self.remote_key.callback(key)
def retry_connection(self, connector):
self.retries += 1
if self.retries < self.max_retries:
connector.connector = None
connector.connect()
else:
self.remote_key.errback(
SynapseKeyClientError("Max retries exceeded"))
def clientConnectionFailed(self, connector, reason):
logger.info("Connection failed %s", reason)
self.retry_connection(connector)
def clientConnectionLost(self, connector, reason):
logger.info("Connection lost %s", reason)
if not self.succeeded:
self.retry_connection(connector)

125
synapse/crypto/keyring.py Normal file
View file

@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.crypto.keyclient import fetch_server_key
from twisted.internet import defer
from syutil.crypto.jsonsign import verify_signed_json, signature_ids
from syutil.crypto.signing_key import (
is_signing_algorithm_supported, decode_verify_key_bytes
)
from syutil.base64util import decode_base64, encode_base64
from OpenSSL import crypto
import logging
logger = logging.getLogger(__name__)
class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.hs = hs
@defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object):
key_ids = signature_ids(json_object, server_name)
verify_key = yield self.get_server_verify_key(server_name, key_ids)
verify_signed_json(json_object, server_name, verify_key)
@defer.inlineCallbacks
def get_server_verify_key(self, server_name, key_ids):
"""Finds a verification key for the server with one of the key ids.
Args:
server_name (str): The name of the server to fetch a key for.
keys_ids (list of str): The key_ids to check for.
"""
# Check the datastore to see if we have one cached.
cached = yield self.store.get_server_verify_keys(server_name, key_ids)
if cached:
defer.returnValue(cached[0])
return
# Try to fetch the key from the remote server.
# TODO(markjh): Ratelimit requests to a given server.
(response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_context_factory
)
# Check the response.
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate
)
if ("signatures" not in response
or server_name not in response["signatures"]):
raise ValueError("Key response not signed by remote server")
if "tls_certificate" not in response:
raise ValueError("Key response missing TLS certificate")
tls_certificate_b64 = response["tls_certificate"]
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
raise ValueError("TLS certificate doesn't match")
verify_keys = {}
for key_id, key_base64 in response["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_keys[key_id] = verify_key
for key_id in response["signatures"][server_name]:
if key_id not in response["verify_keys"]:
raise ValueError(
"Key response must include verification keys for all"
" signatures"
)
if key_id in verify_keys:
verify_signed_json(
response,
server_name,
verify_keys[key_id]
)
# Cache the result in the datastore.
time_now_ms = self.clock.time_msec()
self.store.store_server_certificate(
server_name,
server_name,
time_now_ms,
tls_certificate,
)
for key_id, key in verify_keys.items():
self.store.store_server_verify_key(
server_name, server_name, time_now_ms, key
)
for key_id in key_ids:
if key_id in verify_keys:
defer.returnValue(verify_keys[key_id])
return
raise ValueError("No verification key found for given key ids")

View file

@ -1,111 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import reactor, ssl
from twisted.web import server
from twisted.web.resource import Resource
from twisted.python.log import PythonLoggingObserver
from synapse.crypto.resource.key import LocalKey
from synapse.crypto.config import load_config
from syutil.base64util import decode_base64
from OpenSSL import crypto, SSL
import logging
import nacl.signing
import sys
class KeyServerSSLContextFactory(ssl.ContextFactory):
"""Factory for PyOpenSSL SSL contexts that are used to handle incoming
connections and to make connections to remote servers."""
def __init__(self, key_server):
self._context = SSL.Context(SSL.SSLv23_METHOD)
self.configure_context(self._context, key_server)
@staticmethod
def configure_context(context, key_server):
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
context.use_certificate(key_server.tls_certificate)
context.use_privatekey(key_server.tls_private_key)
context.load_tmp_dh(key_server.tls_dh_params_path)
context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH")
def getContext(self):
return self._context
class KeyServer(object):
"""An HTTPS server serving LocalKey and RemoteKey resources."""
def __init__(self, server_name, tls_certificate_path, tls_private_key_path,
tls_dh_params_path, signing_key_path, bind_host, bind_port):
self.server_name = server_name
self.tls_certificate = self.read_tls_certificate(tls_certificate_path)
self.tls_private_key = self.read_tls_private_key(tls_private_key_path)
self.tls_dh_params_path = tls_dh_params_path
self.signing_key = self.read_signing_key(signing_key_path)
self.bind_host = bind_host
self.bind_port = int(bind_port)
self.ssl_context_factory = KeyServerSSLContextFactory(self)
@staticmethod
def read_tls_certificate(cert_path):
with open(cert_path) as cert_file:
cert_pem = cert_file.read()
return crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
@staticmethod
def read_tls_private_key(private_key_path):
with open(private_key_path) as private_key_file:
private_key_pem = private_key_file.read()
return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem)
@staticmethod
def read_signing_key(signing_key_path):
with open(signing_key_path) as signing_key_file:
signing_key_b64 = signing_key_file.read()
signing_key_bytes = decode_base64(signing_key_b64)
return nacl.signing.SigningKey(signing_key_bytes)
def run(self):
root = Resource()
root.putChild("key", LocalKey(self))
site = server.Site(root)
reactor.listenSSL(
self.bind_port,
site,
self.ssl_context_factory,
interface=self.bind_host
)
logging.basicConfig(level=logging.DEBUG)
observer = PythonLoggingObserver()
observer.start()
reactor.run()
def main():
key_server = KeyServer(**load_config(__doc__, sys.argv[1:]))
key_server.run()
if __name__ == "__main__":
main()

View file

@ -1,15 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -66,6 +66,8 @@ class ReplicationLayer(object):
hs, self.transaction_actions, transport_layer hs, self.transaction_actions, transport_layer
) )
self.keyring = hs.get_keyring()
self.handler = None self.handler = None
self.edu_handlers = {} self.edu_handlers = {}
self.query_handlers = {} self.query_handlers = {}
@ -291,6 +293,10 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_incoming_transaction(self, transaction_data): def on_incoming_transaction(self, transaction_data):
yield self.keyring.verify_json_for_server(
transaction_data["origin"], transaction_data
)
transaction = Transaction(**transaction_data) transaction = Transaction(**transaction_data)
for p in transaction.pdus: for p in transaction.pdus:
@ -590,7 +596,7 @@ class _TransactionQueue(object):
transaction = Transaction.create_new( transaction = Transaction.create_new(
ts=self._clock.time_msec(), ts=self._clock.time_msec(),
transaction_id=self._next_txn_id, transaction_id=str(self._next_txn_id),
origin=self.server_name, origin=self.server_name,
destination=destination, destination=destination,
pdus=pdus, pdus=pdus,
@ -611,20 +617,18 @@ class _TransactionQueue(object):
# FIXME (erikj): This is a bit of a hack to make the Pdu age # FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work # keys work
def cb(transaction): def json_data_cb():
data = transaction.get_dict()
now = int(self._clock.time_msec()) now = int(self._clock.time_msec())
if "pdus" in transaction: if "pdus" in data:
for p in transaction["pdus"]: for p in data["pdus"]:
if "age_ts" in p: if "age_ts" in p:
p["age"] = now - int(p["age_ts"]) p["age"] = now - int(p["age_ts"])
data = sign_json(data, server_name, signing_key)
transaction = sign_json(transaction, server_name, signing_key) return data
return transaction
code, response = yield self.transport_layer.send_transaction( code, response = yield self.transport_layer.send_transaction(
transaction, transaction, json_data_cb
on_send_callback=cb,
) )
logger.debug("TX [%s] Sent transaction", destination) logger.debug("TX [%s] Sent transaction", destination)

View file

@ -144,7 +144,7 @@ class TransportLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def send_transaction(self, transaction, on_send_callback=None): def send_transaction(self, transaction, json_data_callback=None):
""" Sends the given Transaction to it's destination """ Sends the given Transaction to it's destination
Args: Args:
@ -163,24 +163,26 @@ class TransportLayer(object):
if transaction.destination == self.server_name: if transaction.destination == self.server_name:
raise RuntimeError("Transport layer cannot send to itself!") raise RuntimeError("Transport layer cannot send to itself!")
data = transaction.get_dict() if json_data_callback is None:
def json_data_callback():
return transaction.get_dict()
# FIXME (erikj): This is a bit of a hack to make the Pdu age # FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work # keys work
def cb(destination, method, path_bytes, producer): def cb(destination, method, path_bytes, producer):
if not on_send_callback: json_data = json_data_callback()
return del json_data["destination"]
del json_data["transaction_id"]
producer.reset(json_data)
transaction = json.loads(producer.body) json_data = transaction.get_dict()
del json_data["destination"]
new_transaction = on_send_callback(transaction) del json_data["transaction_id"]
producer.reset(new_transaction)
code, response = yield self.client.put_json( code, response = yield self.client.put_json(
transaction.destination, transaction.destination,
path=PREFIX + "/send/%s/" % transaction.transaction_id, path=PREFIX + "/send/%s/" % transaction.transaction_id,
data=data, data=json_data,
on_send_callback=cb, on_send_callback=cb,
) )

View file

@ -186,9 +186,6 @@ class Transaction(JsonEncodedObject):
"previous_ids", "previous_ids",
"pdus", "pdus",
"edus", "edus",
]
internal_keys = [
"transaction_id", "transaction_id",
"destination", "destination",
] ]

View file

@ -34,6 +34,7 @@ from synapse.util.distributor import Distributor
from synapse.util.lockutils import LockManager from synapse.util.lockutils import LockManager
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring
class BaseHomeServer(object): class BaseHomeServer(object):
@ -78,6 +79,7 @@ class BaseHomeServer(object):
'resource_for_server_key', 'resource_for_server_key',
'event_sources', 'event_sources',
'ratelimiter', 'ratelimiter',
'keyring',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -201,6 +203,9 @@ class HomeServer(BaseHomeServer):
def build_ratelimiter(self): def build_ratelimiter(self):
return Ratelimiter() return Ratelimiter()
def build_keyring(self):
return Keyring(self)
def register_servlets(self): def register_servlets(self):
""" Register all servlets associated with this HomeServer. """ Register all servlets associated with this HomeServer.
""" """

View file

@ -56,6 +56,7 @@ SCHEMAS = [
"presence", "presence",
"im", "im",
"room_aliases", "room_aliases",
"keys",
] ]

View file

@ -18,7 +18,8 @@ from _base import SQLBaseStore
from twisted.internet import defer from twisted.internet import defer
import OpenSSL import OpenSSL
import nacl.signing from syutil.crypto.signing_key import decode_verify_key_bytes
import hashlib
class KeyStore(SQLBaseStore): class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys and tls X.509 certificates """Persistence for signature verification keys and tls X.509 certificates
@ -42,62 +43,74 @@ class KeyStore(SQLBaseStore):
) )
defer.returnValue(tls_certificate) defer.returnValue(tls_certificate)
def store_server_certificate(self, server_name, key_server, ts_now_ms, def store_server_certificate(self, server_name, from_server, time_now_ms,
tls_certificate): tls_certificate):
"""Stores the TLS X.509 certificate for the given server """Stores the TLS X.509 certificate for the given server
Args: Args:
server_name (bytes): The name of the server. server_name (str): The name of the server.
key_server (bytes): Where the certificate was looked up from_server (str): Where the certificate was looked up
ts_now_ms (int): The time now in milliseconds time_now_ms (int): The time now in milliseconds
tls_certificate (OpenSSL.crypto.X509): The X.509 certificate. tls_certificate (OpenSSL.crypto.X509): The X.509 certificate.
""" """
tls_certificate_bytes = OpenSSL.crypto.dump_certificate( tls_certificate_bytes = OpenSSL.crypto.dump_certificate(
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate OpenSSL.crypto.FILETYPE_ASN1, tls_certificate
) )
fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest()
return self._simple_insert( return self._simple_insert(
table="server_tls_certificates", table="server_tls_certificates",
keyvalues={ values={
"server_name": server_name, "server_name": server_name,
"key_server": key_server, "fingerprint": fingerprint,
"ts_added_ms": ts_now_ms, "from_server": from_server,
"tls_certificate": tls_certificate_bytes, "ts_added_ms": time_now_ms,
"tls_certificate": buffer(tls_certificate_bytes),
}, },
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verification_key(self, server_name): def get_server_verify_keys(self, server_name, key_ids):
"""Retrieve the NACL verification key for a given server """Retrieve the NACL verification key for a given server for the given
key_ids
Args: Args:
server_name (bytes): The name of the server. server_name (str): The name of the server.
key_ids (list of str): List of key_ids to try and look up.
Returns: Returns:
(nacl.signing.VerifyKey): The verification key. (list of VerifyKey): The verification keys.
""" """
verification_key_bytes, = yield self._simple_select_one( sql = (
table="server_signature_keys", "SELECT key_id, verify_key FROM server_signature_keys"
key_values={"server_name": server_name}, " WHERE server_name = ?"
retcols=("tls_certificate",), " AND key_id in (" + ",".join("?" for key_id in key_ids) + ")"
) )
verification_key = nacl.signing.VerifyKey(verification_key_bytes)
defer.returnValue(verification_key)
def store_server_verification_key(self, server_name, key_version, rows = yield self._execute_and_decode(sql, server_name, *key_ids)
key_server, ts_now_ms, verification_key):
keys = []
for row in rows:
key_id = row["key_id"]
key_bytes = row["verify_key"]
key = decode_verify_key_bytes(key_id, str(key_bytes))
keys.append(key)
defer.returnValue(keys)
def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key):
"""Stores a NACL verification key for the given server. """Stores a NACL verification key for the given server.
Args: Args:
server_name (bytes): The name of the server. server_name (str): The name of the server.
key_version (bytes): The version of the key for the server. key_id (str): The version of the key for the server.
key_server (bytes): Where the verification key was looked up from_server (str): Where the verification key was looked up
ts_now_ms (int): The time now in milliseconds ts_now_ms (int): The time now in milliseconds
verification_key (nacl.signing.VerifyKey): The NACL verify key. verification_key (VerifyKey): The NACL verify key.
""" """
verification_key_bytes = verification_key.encode() verify_key_bytes = verify_key.encode()
return self._simple_insert( return self._simple_insert(
table="server_signature_keys", table="server_signature_keys",
key_values={ values={
"server_name": server_name, "server_name": server_name,
"key_version": key_version, "key_id": "%s:%s" % (verify_key.alg, verify_key.version),
"key_server": key_server, "from_server": from_server,
"ts_added_ms": ts_now_ms, "ts_added_ms": time_now_ms,
"verification_key": verification_key_bytes, "verify_key": buffer(verify_key.encode()),
}, },
) )

View file

@ -14,17 +14,18 @@
*/ */
CREATE TABLE IF NOT EXISTS server_tls_certificates( CREATE TABLE IF NOT EXISTS server_tls_certificates(
server_name TEXT, -- Server name. server_name TEXT, -- Server name.
key_server TEXT, -- Which key server the certificate was fetched from. fingerprint TEXT, -- Certificate fingerprint.
from_server TEXT, -- Which key server the certificate was fetched from.
ts_added_ms INTEGER, -- When the certifcate was added. ts_added_ms INTEGER, -- When the certifcate was added.
tls_certificate BLOB, -- DER encoded x509 certificate. tls_certificate BLOB, -- DER encoded x509 certificate.
CONSTRAINT uniqueness UNIQUE (server_name) CONSTRAINT uniqueness UNIQUE (server_name, fingerprint)
); );
CREATE TABLE IF NOT EXISTS server_signature_keys( CREATE TABLE IF NOT EXISTS server_signature_keys(
server_name TEXT, -- Server name. server_name TEXT, -- Server name.
key_version TEXT, -- Key version. key_id TEXT, -- Key version.
key_server TEXT, -- Which key server the key was fetched form. from_server TEXT, -- Which key server the key was fetched form.
ts_added_ms INTEGER, -- When the key was added. ts_added_ms INTEGER, -- When the key was added.
verification_key BLOB, -- NACL verification key. verify_key BLOB, -- NACL verification key.
CONSTRAINT uniqueness UNIQUE (server_name, key_version) CONSTRAINT uniqueness UNIQUE (server_name, key_id)
); );

View file

@ -74,6 +74,7 @@ class FederationTestCase(unittest.TestCase):
datastore=self.mock_persistence, datastore=self.mock_persistence,
clock=self.clock, clock=self.clock,
config=self.mock_config, config=self.mock_config,
keyring=Mock(),
) )
self.federation = initialize_http_replication(hs) self.federation = initialize_http_replication(hs)
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()

View file

@ -17,7 +17,7 @@
from tests import unittest from tests import unittest
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from mock import Mock, call, ANY, NonCallableMock from mock import Mock, call, ANY, NonCallableMock, patch
import json import json
from tests.utils import ( from tests.utils import (
@ -59,7 +59,6 @@ class JustPresenceHandlers(object):
def __init__(self, hs): def __init__(self, hs):
self.presence_handler = PresenceHandler(hs) self.presence_handler = PresenceHandler(hs)
class PresenceStateTestCase(unittest.TestCase): class PresenceStateTestCase(unittest.TestCase):
""" Tests presence management. """ """ Tests presence management. """
@ -78,6 +77,7 @@ class PresenceStateTestCase(unittest.TestCase):
resource_for_federation=Mock(), resource_for_federation=Mock(),
http_client=None, http_client=None,
config=self.mock_config, config=self.mock_config,
keyring=Mock(),
) )
hs.handlers = JustPresenceHandlers(hs) hs.handlers = JustPresenceHandlers(hs)
@ -230,6 +230,7 @@ class PresenceInvitesTestCase(unittest.TestCase):
resource_for_federation=self.mock_federation_resource, resource_for_federation=self.mock_federation_resource,
http_client=self.mock_http_client, http_client=self.mock_http_client,
config=self.mock_config, config=self.mock_config,
keyring=Mock(),
) )
hs.handlers = JustPresenceHandlers(hs) hs.handlers = JustPresenceHandlers(hs)
@ -533,6 +534,7 @@ class PresencePushTestCase(unittest.TestCase):
resource_for_federation=self.mock_federation_resource, resource_for_federation=self.mock_federation_resource,
http_client=self.mock_http_client, http_client=self.mock_http_client,
config=self.mock_config, config=self.mock_config,
keyring=Mock(),
) )
hs.handlers = JustPresenceHandlers(hs) hs.handlers = JustPresenceHandlers(hs)
@ -1026,6 +1028,7 @@ class PresencePollingTestCase(unittest.TestCase):
resource_for_federation=self.mock_federation_resource, resource_for_federation=self.mock_federation_resource,
http_client=self.mock_http_client, http_client=self.mock_http_client,
config=self.mock_config, config=self.mock_config,
keyring=Mock(),
) )
hs.handlers = JustPresenceHandlers(hs) hs.handlers = JustPresenceHandlers(hs)

View file

@ -79,6 +79,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
resource_for_federation=self.mock_federation_resource, resource_for_federation=self.mock_federation_resource,
http_client=self.mock_http_client, http_client=self.mock_http_client,
config=self.mock_config, config=self.mock_config,
keyring=Mock(),
) )
hs.handlers = JustTypingNotificationHandlers(hs) hs.handlers = JustTypingNotificationHandlers(hs)