Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2019-01-30 11:48:29 +00:00
commit d351be1567
92 changed files with 2135 additions and 527 deletions

1
.gitignore vendored
View file

@ -12,6 +12,7 @@ dbs/
dist/
docs/build/
*.egg-info
pip-wheel-metadata/
cmdclient_config.json
homeserver*.db

View file

@ -15,6 +15,7 @@ recursive-include docs *
recursive-include scripts *
recursive-include scripts-dev *
recursive-include synapse *.pyi
recursive-include tests *.pem
recursive-include tests *.py
recursive-include synapse/res *

1
changelog.d/4408.feature Normal file
View file

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View file

@ -1 +0,0 @@
Refactor 'sign_request' as 'build_auth_headers'

1
changelog.d/4409.feature Normal file
View file

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View file

@ -1 +0,0 @@
Remove redundant federation connection wrapping code

1
changelog.d/4426.feature Normal file
View file

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View file

@ -1 +0,0 @@
Remove redundant SynapseKeyClientProtocol magic

1
changelog.d/4427.feature Normal file
View file

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View file

@ -1 +0,0 @@
Refactor and cleanup for SRV record lookup

1
changelog.d/4428.feature Normal file
View file

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View file

@ -1 +0,0 @@
Move SRV logic into the Agent layer

1
changelog.d/4464.feature Normal file
View file

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View file

@ -1 +0,0 @@
Move SRV logic into the Agent layer

1
changelog.d/4468.feature Normal file
View file

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View file

@ -1 +0,0 @@
Move SRV logic into the Agent layer

1
changelog.d/4481.misc Normal file
View file

@ -0,0 +1 @@
Add infrastructure to support different event formats

1
changelog.d/4483.feature Normal file
View file

@ -0,0 +1 @@
Add support for room version 3

1
changelog.d/4486.bugfix Normal file
View file

@ -0,0 +1 @@
Workaround for login error when using both LDAP and internal authentication.

1
changelog.d/4487.feature Normal file
View file

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View file

@ -1 +0,0 @@
Fix idna and ipv6 literal handling in MatrixFederationAgent

1
changelog.d/4489.feature Normal file
View file

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

1
changelog.d/4494.misc Normal file
View file

@ -0,0 +1 @@
Add infrastructure to support different event formats

1
changelog.d/4495.feature Normal file
View file

@ -0,0 +1 @@
Synapse will now reload TLS certificates from disk upon SIGHUP.

1
changelog.d/4496.misc Normal file
View file

@ -0,0 +1 @@
Add infrastructure to support different event formats

1
changelog.d/4498.misc Normal file
View file

@ -0,0 +1 @@
Clarify documentation for the `public_baseurl` config param

1
changelog.d/4499.feature Normal file
View file

@ -0,0 +1 @@
Add support for room version 3

1
changelog.d/4506.misc Normal file
View file

@ -0,0 +1 @@
Make it possible to set the log level for tests via an environment variable

1
changelog.d/4507.misc Normal file
View file

@ -0,0 +1 @@
Reduce the log level of linearizer lock acquirement to DEBUG.

1
changelog.d/4509.removal Normal file
View file

@ -0,0 +1 @@
Synapse no longer generates self-signed TLS certificates when generating a configuration file.

1
changelog.d/4510.misc Normal file
View file

@ -0,0 +1 @@
Add infrastructure to support different event formats

1
changelog.d/4511.feature Normal file
View file

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

1
changelog.d/4512.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug where setting a relative consent directory path would cause a crash.

1
changelog.d/4514.misc Normal file
View file

@ -0,0 +1 @@
Add infrastructure to support different event formats

1
changelog.d/4515.feature Normal file
View file

@ -0,0 +1 @@
Add support for room version 3

1
changelog.d/4516.feature Normal file
View file

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

1
changelog.d/4519.misc Normal file
View file

@ -0,0 +1 @@
Fix code to comply with linting in PyFlakes 3.7.1.

1
changelog.d/4520.feature Normal file
View file

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

1
changelog.d/4521.feature Normal file
View file

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

View file

@ -46,7 +46,7 @@ def request_registration(
# Get the nonce
r = requests.get(url, verify=False)
if r.status_code is not 200:
if r.status_code != 200:
_print("ERROR! Received %d %s" % (r.status_code, r.reason))
if 400 <= r.status_code < 500:
try:
@ -84,7 +84,7 @@ def request_registration(
_print("Sending registration request...")
r = requests.post(url, json=data, verify=False)
if r.status_code is not 200:
if r.status_code != 200:
_print("ERROR! Received %d %s" % (r.status_code, r.reason))
if 400 <= r.status_code < 500:
try:

View file

@ -550,17 +550,6 @@ class Auth(object):
"""
return self.store.is_server_admin(user)
@defer.inlineCallbacks
def add_auth_events(self, builder, context):
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_ids = yield self.compute_auth_events(builder, prev_state_ids)
auth_events_entries = yield self.store.add_event_hashes(
auth_ids
)
builder.auth_events = auth_events_entries
@defer.inlineCallbacks
def compute_auth_events(self, event, current_state_ids, for_verification=False):
if event.type == EventTypes.Create:
@ -577,7 +566,7 @@ class Auth(object):
key = (EventTypes.JoinRules, "", )
join_rule_event_id = current_state_ids.get(key)
key = (EventTypes.Member, event.user_id, )
key = (EventTypes.Member, event.sender, )
member_event_id = current_state_ids.get(key)
key = (EventTypes.Create, "", )
@ -627,7 +616,7 @@ class Auth(object):
defer.returnValue(auth_ids)
def check_redaction(self, event, auth_events):
def check_redaction(self, room_version, event, auth_events):
"""Check whether the event sender is allowed to redact the target event.
Returns:
@ -640,7 +629,7 @@ class Auth(object):
AuthError if the event sender is definitely not allowed to redact
the target event.
"""
return event_auth.check_redaction(event, auth_events)
return event_auth.check_redaction(room_version, event, auth_events)
@defer.inlineCallbacks
def check_can_change_room_list(self, room_id, user):

View file

@ -104,7 +104,7 @@ class ThirdPartyEntityKind(object):
class RoomVersions(object):
V1 = "1"
V2 = "2"
VDH_TEST = "vdh-test-version"
V3 = "3"
STATE_V2_TEST = "state-v2-test"
@ -116,7 +116,7 @@ DEFAULT_ROOM_VERSION = RoomVersions.V1
KNOWN_ROOM_VERSIONS = {
RoomVersions.V1,
RoomVersions.V2,
RoomVersions.VDH_TEST,
RoomVersions.V3,
RoomVersions.STATE_V2_TEST,
}
@ -126,10 +126,12 @@ class EventFormatVersions(object):
independently from the room version.
"""
V1 = 1
V2 = 2
KNOWN_EVENT_FORMAT_VERSIONS = {
EventFormatVersions.V1,
EventFormatVersions.V2,
}

View file

@ -143,6 +143,9 @@ def listen_metrics(bind_addresses, port):
def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
"""
Create a TCP socket for a port and several addresses
Returns:
list (empty)
"""
for address in bind_addresses:
try:
@ -155,25 +158,37 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
logger.info("Synapse now listening on TCP port %d", port)
return []
def listen_ssl(
bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
):
"""
Create an SSL socket for a port and several addresses
Create an TLS-over-TCP socket for a port and several addresses
Returns:
list of twisted.internet.tcp.Port listening for TLS connections
"""
r = []
for address in bind_addresses:
try:
reactor.listenSSL(
port,
factory,
context_factory,
backlog,
address
r.append(
reactor.listenSSL(
port,
factory,
context_factory,
backlog,
address
)
)
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
logger.info("Synapse now listening on port %d (TLS)", port)
return r
def check_bind_error(e, address, bind_addresses):
"""

View file

@ -17,6 +17,7 @@
import gc
import logging
import os
import signal
import sys
import traceback
@ -27,6 +28,7 @@ from prometheus_client import Gauge
from twisted.application import service
from twisted.internet import defer, reactor
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.web.resource import EncodingResourceWrapper, NoResource
from twisted.web.server import GzipEncoderFactory
from twisted.web.static import File
@ -84,6 +86,7 @@ def gz_wrap(r):
class SynapseHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
_listening_services = []
def _listener_http(self, config, listener_config):
port = listener_config["port"]
@ -121,7 +124,7 @@ class SynapseHomeServer(HomeServer):
root_resource = create_resource_tree(resources, root_resource)
if tls:
listen_ssl(
return listen_ssl(
bind_addresses,
port,
SynapseSite(
@ -135,7 +138,7 @@ class SynapseHomeServer(HomeServer):
)
else:
listen_tcp(
return listen_tcp(
bind_addresses,
port,
SynapseSite(
@ -146,7 +149,6 @@ class SynapseHomeServer(HomeServer):
self.version_string,
)
)
logger.info("Synapse now listening on port %d", port)
def _configure_named_resource(self, name, compress=False):
"""Build a resource map for a named resource
@ -242,7 +244,9 @@ class SynapseHomeServer(HomeServer):
for listener in config.listeners:
if listener["type"] == "http":
self._listener_http(config, listener)
self._listening_services.extend(
self._listener_http(config, listener)
)
elif listener["type"] == "manhole":
listen_tcp(
listener["bind_addresses"],
@ -322,7 +326,19 @@ def setup(config_options):
# generating config files and shouldn't try to continue.
sys.exit(0)
synapse.config.logger.setup_logging(config, use_worker_options=False)
sighup_callbacks = []
synapse.config.logger.setup_logging(
config,
use_worker_options=False,
register_sighup=sighup_callbacks.append
)
def handle_sighup(*args, **kwargs):
for i in sighup_callbacks:
i(*args, **kwargs)
if hasattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, handle_sighup)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
@ -359,6 +375,31 @@ def setup(config_options):
hs.setup()
def refresh_certificate(*args):
"""
Refresh the TLS certificates that Synapse is using by re-reading them
from disk and updating the TLS context factories to use them.
"""
logging.info("Reloading certificate from disk...")
hs.config.read_certificate_from_disk()
hs.tls_server_context_factory = context_factory.ServerContextFactory(config)
hs.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
config
)
logging.info("Certificate reloaded.")
logging.info("Updating context factories...")
for i in hs._listening_services:
if isinstance(i.factory, TLSMemoryBIOFactory):
i.factory = TLSMemoryBIOFactory(
hs.tls_server_context_factory,
False,
i.factory.wrappedFactory
)
logging.info("Context factories updated.")
sighup_callbacks.append(refresh_certificate)
@defer.inlineCallbacks
def start():
try:

View file

@ -13,6 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from os import path
from synapse.config import ConfigError
from ._base import Config
DEFAULT_CONFIG = """\
@ -85,7 +89,15 @@ class ConsentConfig(Config):
if consent_config is None:
return
self.user_consent_version = str(consent_config["version"])
self.user_consent_template_dir = consent_config["template_dir"]
self.user_consent_template_dir = self.abspath(
consent_config["template_dir"]
)
if not path.isdir(self.user_consent_template_dir):
raise ConfigError(
"Could not find template directory '%s'" % (
self.user_consent_template_dir,
),
)
self.user_consent_server_notice_content = consent_config.get(
"server_notice_content",
)

View file

@ -127,7 +127,7 @@ class LoggingConfig(Config):
)
def setup_logging(config, use_worker_options=False):
def setup_logging(config, use_worker_options=False, register_sighup=None):
""" Set up python logging
Args:
@ -136,7 +136,16 @@ def setup_logging(config, use_worker_options=False):
use_worker_options (bool): True to use 'worker_log_config' and
'worker_log_file' options instead of 'log_config' and 'log_file'.
register_sighup (func | None): Function to call to register a
sighup handler.
"""
if not register_sighup:
if getattr(signal, "SIGHUP"):
register_sighup = lambda x: signal.signal(signal.SIGHUP, x)
else:
register_sighup = lambda x: None
log_config = (config.worker_log_config if use_worker_options
else config.log_config)
log_file = (config.worker_log_file if use_worker_options
@ -198,13 +207,7 @@ def setup_logging(config, use_worker_options=False):
load_log_config()
# TODO(paul): obviously this is a terrible mechanism for
# stealing SIGHUP, because it means no other part of synapse
# can use it instead. If we want to catch SIGHUP anywhere
# else as well, I'd suggest we find a nicer way to broadcast
# it around.
if getattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, sighup)
register_sighup(sighup)
# make sure that the first thing we log is a thing we can grep backwards
# for

View file

@ -261,7 +261,7 @@ class ServerConfig(Config):
# enter into the 'custom HS URL' field on their client. If you
# use synapse with a reverse proxy, this should be the URL to reach
# synapse via the proxy.
# public_baseurl: https://example.com:8448/
# public_baseurl: https://example.com/
# Set the soft limit on the number of file descriptors synapse can use
# Zero is used to indicate synapse should set the soft limit to the

View file

@ -15,6 +15,7 @@
import logging
import os
import warnings
from datetime import datetime
from hashlib import sha256
@ -39,8 +40,8 @@ class TlsConfig(Config):
self.acme_bind_addresses = acme_config.get("bind_addresses", ["127.0.0.1"])
self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
self.tls_certificate_file = os.path.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = os.path.abspath(config.get("tls_private_key_path"))
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
self._original_tls_fingerprints = config["tls_fingerprints"]
self.tls_fingerprints = list(self._original_tls_fingerprints)
self.no_tls = config.get("no_tls", False)
@ -94,6 +95,16 @@ class TlsConfig(Config):
"""
self.tls_certificate = self.read_tls_certificate(self.tls_certificate_file)
# Check if it is self-signed, and issue a warning if so.
if self.tls_certificate.get_issuer() == self.tls_certificate.get_subject():
warnings.warn(
(
"Self-signed TLS certificates will not be accepted by Synapse 1.0. "
"Please either provide a valid certificate, or use Synapse's ACME "
"support to provision one."
)
)
if not self.no_tls:
self.tls_private_key = self.read_tls_private_key(self.tls_private_key_file)
@ -118,10 +129,11 @@ class TlsConfig(Config):
return (
"""\
# PEM encoded X509 certificate for TLS.
# You can replace the self-signed certificate that synapse
# autogenerates on launch with your own SSL certificate + key pair
# if you like. Any required intermediary certificates can be
# appended after the primary certificate in hierarchical order.
# This certificate, as of Synapse 1.0, will need to be a valid
# and verifiable certificate, with a root that is available in
# the root store of other servers you wish to federate to. Any
# required intermediary certificates can be appended after the
# primary certificate in hierarchical order.
tls_certificate_path: "%(tls_certificate_path)s"
# PEM encoded private key for TLS
@ -183,40 +195,3 @@ class TlsConfig(Config):
def read_tls_private_key(self, private_key_path):
private_key_pem = self.read_file(private_key_path, "tls_private_key")
return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem)
def generate_files(self, config):
tls_certificate_path = config["tls_certificate_path"]
tls_private_key_path = config["tls_private_key_path"]
if not self.path_exists(tls_private_key_path):
with open(tls_private_key_path, "wb") as private_key_file:
tls_private_key = crypto.PKey()
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
private_key_pem = crypto.dump_privatekey(
crypto.FILETYPE_PEM, tls_private_key
)
private_key_file.write(private_key_pem)
else:
with open(tls_private_key_path) as private_key_file:
private_key_pem = private_key_file.read()
tls_private_key = crypto.load_privatekey(
crypto.FILETYPE_PEM, private_key_pem
)
if not self.path_exists(tls_certificate_path):
with open(tls_certificate_path, "wb") as certificate_file:
cert = crypto.X509()
subject = cert.get_subject()
subject.CN = config["server_name"]
cert.set_serial_number(1000)
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)
cert.set_issuer(cert.get_subject())
cert.set_pubkey(tls_private_key)
cert.sign(tls_private_key, 'sha256')
cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
certificate_file.write(cert_pem)

View file

@ -131,12 +131,12 @@ def compute_event_signature(event_dict, signature_name, signing_key):
return redact_json["signatures"]
def add_hashes_and_signatures(event, signature_name, signing_key,
def add_hashes_and_signatures(event_dict, signature_name, signing_key,
hash_algorithm=hashlib.sha256):
"""Add content hash and sign the event
Args:
event_dict (EventBuilder): The event to add hashes to and sign
event_dict (dict): The event to add hashes to and sign
signature_name (str): The name of the entity signing the event
(typically the server's hostname).
signing_key (syutil.crypto.SigningKey): The key to sign with
@ -144,16 +144,12 @@ def add_hashes_and_signatures(event, signature_name, signing_key,
to hash the event
"""
name, digest = compute_content_hash(
event.get_pdu_json(), hash_algorithm=hash_algorithm,
)
name, digest = compute_content_hash(event_dict, hash_algorithm=hash_algorithm)
if not hasattr(event, "hashes"):
event.hashes = {}
event.hashes[name] = encode_base64(digest)
event_dict.setdefault("hashes", {})[name] = encode_base64(digest)
event.signatures = compute_event_signature(
event.get_pdu_json(),
event_dict["signatures"] = compute_event_signature(
event_dict,
signature_name=signature_name,
signing_key=signing_key,
)

View file

@ -20,7 +20,14 @@ from signedjson.key import decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64
from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, JoinRules, Membership
from synapse.api.constants import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
EventTypes,
JoinRules,
Membership,
RoomVersions,
)
from synapse.api.errors import AuthError, EventSizeError, SynapseError
from synapse.types import UserID, get_domain_from_id
@ -49,7 +56,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
if do_sig_check:
sender_domain = get_domain_from_id(event.sender)
event_id_domain = get_domain_from_id(event.event_id)
is_invite_via_3pid = (
event.type == EventTypes.Member
@ -66,9 +72,13 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
if not is_invite_via_3pid:
raise AuthError(403, "Event not signed by sender's server")
# Check the event_id's domain has signed the event
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if event.format_version in (EventFormatVersions.V1,):
# Only older room versions have event IDs to check.
event_id_domain = get_domain_from_id(event.event_id)
# Check the origin domain has signed the event
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
@ -168,7 +178,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
_check_power_levels(event, auth_events)
if event.type == EventTypes.Redaction:
check_redaction(event, auth_events)
check_redaction(room_version, event, auth_events)
logger.debug("Allowing! %s", event)
@ -422,7 +432,7 @@ def _can_send_event(event, auth_events):
return True
def check_redaction(event, auth_events):
def check_redaction(room_version, event, auth_events):
"""Check whether the event sender is allowed to redact the target event.
Returns:
@ -442,10 +452,16 @@ def check_redaction(event, auth_events):
if user_level >= redact_level:
return False
redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
if room_version in (RoomVersions.V1, RoomVersions.V2,):
redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
return True
elif room_version == RoomVersions.V3:
event.internal_metadata.recheck_redaction = True
return True
else:
raise RuntimeError("Unrecognized room version %r" % (room_version,))
raise AuthError(
403,

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,11 +19,9 @@ from distutils.util import strtobool
import six
from synapse.api.constants import (
KNOWN_EVENT_FORMAT_VERSIONS,
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
)
from unpaddedbase64 import encode_base64
from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventFormatVersions, RoomVersions
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
@ -63,6 +62,21 @@ class _EventInternalMetadata(object):
"""
return getattr(self, "send_on_behalf_of", None)
def need_to_check_redaction(self):
"""Whether the redaction event needs to be rechecked when fetching
from the database.
Starting in room v3 redaction events are accepted up front, and later
checked to see if the redacter and redactee's domains match.
If the sender of the redaction event is allowed to redact any event
due to auth rules, then this will always return false.
Returns:
bool
"""
return getattr(self, "recheck_redaction", False)
def _event_dict_property(key):
# We want to be able to use hasattr with the event dict properties.
@ -225,16 +239,6 @@ class FrozenEvent(EventBase):
rejected_reason=rejected_reason,
)
@staticmethod
def from_event(event):
e = FrozenEvent(
event.get_pdu_json()
)
e.internal_metadata = event.internal_metadata
return e
def __str__(self):
return self.__repr__()
@ -246,6 +250,85 @@ class FrozenEvent(EventBase):
)
class FrozenEventV2(EventBase):
format_version = EventFormatVersions.V2 # All events of this type are V2
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a
# copy.deepcopy
signatures = {
name: {sig_id: sig for sig_id, sig in sigs.items()}
for name, sigs in event_dict.pop("signatures", {}).items()
}
assert "event_id" not in event_dict
unsigned = dict(event_dict.pop("unsigned", {}))
# We intern these strings because they turn up a lot (especially when
# caching).
event_dict = intern_dict(event_dict)
if USE_FROZEN_DICTS:
frozen_dict = freeze(event_dict)
else:
frozen_dict = event_dict
self._event_id = None
self.type = event_dict["type"]
if "state_key" in event_dict:
self.state_key = event_dict["state_key"]
super(FrozenEventV2, self).__init__(
frozen_dict,
signatures=signatures,
unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
rejected_reason=rejected_reason,
)
@property
def event_id(self):
# We have to import this here as otherwise we get an import loop which
# is hard to break.
from synapse.crypto.event_signing import compute_event_reference_hash
if self._event_id:
return self._event_id
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
return self._event_id
def prev_event_ids(self):
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's prev_events
"""
return self.prev_events
def auth_event_ids(self):
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's auth_events
"""
return self.auth_events
def __str__(self):
return self.__repr__()
def __repr__(self):
return "<FrozenEventV2 event_id='%s', type='%s', state_key='%s'>" % (
self.event_id,
self.get("type", None),
self.get("state_key", None),
)
def room_version_to_event_format(room_version):
"""Converts a room version string to the event format
@ -259,7 +342,14 @@ def room_version_to_event_format(room_version):
# We should have already checked version, so this should not happen
raise RuntimeError("Unrecognized room version %s" % (room_version,))
return EventFormatVersions.V1
if room_version in (
RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST,
):
return EventFormatVersions.V1
elif room_version in (RoomVersions.V3,):
return EventFormatVersions.V2
else:
raise RuntimeError("Unrecognized room version %s" % (room_version,))
def event_type_from_format_version(format_version):
@ -273,8 +363,12 @@ def event_type_from_format_version(format_version):
type: A type that can be initialized as per the initializer of
`FrozenEvent`
"""
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
if format_version == EventFormatVersions.V1:
return FrozenEvent
elif format_version == EventFormatVersions.V2:
return FrozenEventV2
else:
raise Exception(
"No event format %r" % (format_version,)
)
return FrozenEvent

View file

@ -13,79 +13,161 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import attr
from synapse.api.constants import RoomVersions
from twisted.internet import defer
from synapse.api.constants import (
KNOWN_EVENT_FORMAT_VERSIONS,
KNOWN_ROOM_VERSIONS,
MAX_DEPTH,
EventFormatVersions,
)
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.types import EventID
from synapse.util.stringutils import random_string
from . import EventBase, FrozenEvent, _event_dict_property
from . import (
_EventInternalMetadata,
event_type_from_format_version,
room_version_to_event_format,
)
def get_event_builder(room_version, key_values={}, internal_metadata_dict={}):
"""Generate an event builder appropriate for the given room version
@attr.s(slots=True, cmp=False, frozen=True)
class EventBuilder(object):
"""A format independent event builder used to build up the event content
before signing the event.
Args:
room_version (str): Version of the room that we're creating an
event builder for
key_values (dict): Fields used as the basis of the new event
internal_metadata_dict (dict): Used to create the `_EventInternalMetadata`
object.
(Note that while objects of this class are frozen, the
content/unsigned/internal_metadata fields are still mutable)
Returns:
EventBuilder
Attributes:
format_version (int): Event format version
room_id (str)
type (str)
sender (str)
content (dict)
unsigned (dict)
internal_metadata (_EventInternalMetadata)
_state (StateHandler)
_auth (synapse.api.Auth)
_store (DataStore)
_clock (Clock)
_hostname (str): The hostname of the server creating the event
_signing_key: The signing key to use to sign the event as the server
"""
if room_version in {
RoomVersions.V1,
RoomVersions.V2,
RoomVersions.VDH_TEST,
RoomVersions.STATE_V2_TEST,
}:
return EventBuilder(key_values, internal_metadata_dict)
else:
raise Exception(
"No event format defined for version %r" % (room_version,)
_state = attr.ib()
_auth = attr.ib()
_store = attr.ib()
_clock = attr.ib()
_hostname = attr.ib()
_signing_key = attr.ib()
format_version = attr.ib()
room_id = attr.ib()
type = attr.ib()
sender = attr.ib()
content = attr.ib(default=attr.Factory(dict))
unsigned = attr.ib(default=attr.Factory(dict))
# These only exist on a subset of events, so they raise AttributeError if
# someone tries to get them when they don't exist.
_state_key = attr.ib(default=None)
_redacts = attr.ib(default=None)
internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
@property
def state_key(self):
if self._state_key is not None:
return self._state_key
raise AttributeError("state_key")
def is_state(self):
return self._state_key is not None
@defer.inlineCallbacks
def build(self, prev_event_ids):
"""Transform into a fully signed and hashed event
Args:
prev_event_ids (list[str]): The event IDs to use as the prev events
Returns:
Deferred[FrozenEvent]
"""
state_ids = yield self._state.get_current_state_ids(
self.room_id, prev_event_ids,
)
auth_ids = yield self._auth.compute_auth_events(
self, state_ids,
)
if self.format_version == EventFormatVersions.V1:
auth_events = yield self._store.add_event_hashes(auth_ids)
prev_events = yield self._store.add_event_hashes(prev_event_ids)
else:
auth_events = auth_ids
prev_events = prev_event_ids
class EventBuilder(EventBase):
def __init__(self, key_values={}, internal_metadata_dict={}):
signatures = copy.deepcopy(key_values.pop("signatures", {}))
unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
super(EventBuilder, self).__init__(
key_values,
signatures=signatures,
unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
old_depth = yield self._store.get_max_depth_of(
prev_event_ids,
)
depth = old_depth + 1
event_id = _event_dict_property("event_id")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
# we cap depth of generated events, to ensure that they are not
# rejected by other servers (and so that they can be persisted in
# the db)
depth = min(depth, MAX_DEPTH)
def build(self):
return FrozenEvent.from_event(self)
event_dict = {
"auth_events": auth_events,
"prev_events": prev_events,
"type": self.type,
"room_id": self.room_id,
"sender": self.sender,
"content": self.content,
"unsigned": self.unsigned,
"depth": depth,
"prev_state": [],
}
if self.is_state():
event_dict["state_key"] = self._state_key
if self._redacts is not None:
event_dict["redacts"] = self._redacts
defer.returnValue(
create_local_event_from_event_dict(
clock=self._clock,
hostname=self._hostname,
signing_key=self._signing_key,
format_version=self.format_version,
event_dict=event_dict,
internal_metadata_dict=self.internal_metadata.get_dict(),
)
)
class EventBuilderFactory(object):
def __init__(self, clock, hostname):
self.clock = clock
self.hostname = hostname
def __init__(self, hs):
self.clock = hs.get_clock()
self.hostname = hs.hostname
self.signing_key = hs.config.signing_key[0]
self.event_id_count = 0
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
def create_event_id(self):
i = str(self.event_id_count)
self.event_id_count += 1
local_part = str(int(self.clock.time())) + i + random_string(5)
e_id = EventID(local_part, self.hostname)
return e_id.to_string()
def new(self, room_version, key_values={}):
def new(self, room_version, key_values):
"""Generate an event builder appropriate for the given room version
Args:
@ -98,27 +180,103 @@ class EventBuilderFactory(object):
"""
# There's currently only the one event version defined
if room_version not in {
RoomVersions.V1,
RoomVersions.V2,
RoomVersions.VDH_TEST,
RoomVersions.STATE_V2_TEST,
}:
if room_version not in KNOWN_ROOM_VERSIONS:
raise Exception(
"No event format defined for version %r" % (room_version,)
)
key_values["event_id"] = self.create_event_id()
return EventBuilder(
store=self.store,
state=self.state,
auth=self.auth,
clock=self.clock,
hostname=self.hostname,
signing_key=self.signing_key,
format_version=room_version_to_event_format(room_version),
type=key_values["type"],
state_key=key_values.get("state_key"),
room_id=key_values["room_id"],
sender=key_values["sender"],
content=key_values.get("content", {}),
unsigned=key_values.get("unsigned", {}),
redacts=key_values.get("redacts", None),
)
time_now = int(self.clock.time_msec())
key_values.setdefault("origin", self.hostname)
key_values.setdefault("origin_server_ts", time_now)
def create_local_event_from_event_dict(clock, hostname, signing_key,
format_version, event_dict,
internal_metadata_dict=None):
"""Takes a fully formed event dict, ensuring that fields like `origin`
and `origin_server_ts` have correct values for a locally produced event,
then signs and hashes it.
key_values.setdefault("unsigned", {})
age = key_values["unsigned"].pop("age", 0)
key_values["unsigned"].setdefault("age_ts", time_now - age)
Args:
clock (Clock)
hostname (str)
signing_key
format_version (int)
event_dict (dict)
internal_metadata_dict (dict|None)
key_values["signatures"] = {}
Returns:
FrozenEvent
"""
return EventBuilder(key_values=key_values,)
# There's currently only the one event version defined
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
raise Exception(
"No event format defined for version %r" % (format_version,)
)
if internal_metadata_dict is None:
internal_metadata_dict = {}
time_now = int(clock.time_msec())
if format_version == EventFormatVersions.V1:
event_dict["event_id"] = _create_event_id(clock, hostname)
event_dict["origin"] = hostname
event_dict["origin_server_ts"] = time_now
event_dict.setdefault("unsigned", {})
age = event_dict["unsigned"].pop("age", 0)
event_dict["unsigned"].setdefault("age_ts", time_now - age)
event_dict.setdefault("signatures", {})
add_hashes_and_signatures(
event_dict,
hostname,
signing_key,
)
return event_type_from_format_version(format_version)(
event_dict, internal_metadata_dict=internal_metadata_dict,
)
# A counter used when generating new event IDs
_event_id_counter = 0
def _create_event_id(clock, hostname):
"""Create a new event ID
Args:
clock (Clock)
hostname (str): The server name for the event ID
Returns:
str
"""
global _event_id_counter
i = str(_event_id_counter)
_event_id_counter += 1
local_part = str(int(clock.time())) + i + random_string(5)
e_id = EventID(local_part, hostname)
return e_id.to_string()

View file

@ -267,6 +267,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
Returns:
dict
"""
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase):
return e
@ -276,6 +277,8 @@ def serialize_event(e, time_now_ms, as_client_event=True,
# Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()}
d["event_id"] = e.event_id
if "age_ts" in d["unsigned"]:
d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"]
del d["unsigned"]["age_ts"]

View file

@ -15,23 +15,29 @@
from six import string_types
from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import EventFormatVersions, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.types import EventID, RoomID, UserID
class EventValidator(object):
def validate_new(self, event):
"""Validates the event has roughly the right format
def validate(self, event):
EventID.from_string(event.event_id)
RoomID.from_string(event.room_id)
Args:
event (FrozenEvent)
"""
self.validate_builder(event)
if event.format_version == EventFormatVersions.V1:
EventID.from_string(event.event_id)
required = [
# "auth_events",
"auth_events",
"content",
# "hashes",
"hashes",
"origin",
# "prev_events",
"prev_events",
"sender",
"type",
]
@ -41,8 +47,25 @@ class EventValidator(object):
raise SynapseError(400, "Event does not have key %s" % (k,))
# Check that the following keys have string values
strings = [
event_strings = [
"origin",
]
for s in event_strings:
if not isinstance(getattr(event, s), string_types):
raise SynapseError(400, "'%s' not a string type" % (s,))
def validate_builder(self, event):
"""Validates that the builder/event has roughly the right format. Only
checks values that we expect a proto event to have, rather than all the
fields an event would have
Args:
event (EventBuilder|FrozenEvent)
"""
strings = [
"room_id",
"sender",
"type",
]
@ -54,22 +77,7 @@ class EventValidator(object):
if not isinstance(getattr(event, s), string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,))
if event.type == EventTypes.Member:
if "membership" not in event.content:
raise SynapseError(400, "Content has not membership key")
if event.content["membership"] not in Membership.LIST:
raise SynapseError(400, "Invalid membership key")
# Check that the following keys have dictionary values
# TODO
# Check that the following keys have the correct format for DAGs
# TODO
def validate_new(self, event):
self.validate(event)
RoomID.from_string(event.room_id)
UserID.from_string(event.sender)
if event.type == EventTypes.Message:
@ -86,9 +94,16 @@ class EventValidator(object):
elif event.type == EventTypes.Name:
self._ensure_strings(event.content, ["name"])
elif event.type == EventTypes.Member:
if "membership" not in event.content:
raise SynapseError(400, "Content has not membership key")
if event.content["membership"] not in Membership.LIST:
raise SynapseError(400, "Invalid membership key")
def _ensure_strings(self, d, keys):
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
if not isinstance(d[s], string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,))
raise SynapseError(400, "'%s' not a string type" % (s,))

View file

@ -20,7 +20,7 @@ import six
from twisted.internet import defer
from twisted.internet.defer import DeferredList
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import event_type_from_format_version
@ -66,7 +66,7 @@ class FederationBase(object):
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
"""
deferreds = self._check_sigs_and_hashes(pdus)
deferreds = self._check_sigs_and_hashes(room_version, pdus)
@defer.inlineCallbacks
def handle_check_result(pdu, deferred):
@ -121,16 +121,17 @@ class FederationBase(object):
else:
defer.returnValue([p for p in valid_pdus if p])
def _check_sigs_and_hash(self, pdu):
def _check_sigs_and_hash(self, room_version, pdu):
return logcontext.make_deferred_yieldable(
self._check_sigs_and_hashes([pdu])[0],
self._check_sigs_and_hashes(room_version, [pdu])[0],
)
def _check_sigs_and_hashes(self, pdus):
def _check_sigs_and_hashes(self, room_version, pdus):
"""Checks that each of the received events is correctly signed by the
sending server.
Args:
room_version (str): The room version of the PDUs
pdus (list[FrozenEvent]): the events to be checked
Returns:
@ -141,7 +142,7 @@ class FederationBase(object):
* throws a SynapseError if the signature check failed.
The deferreds run their callbacks in the sentinel logcontext.
"""
deferreds = _check_sigs_on_pdus(self.keyring, pdus)
deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
ctx = logcontext.LoggingContext.current_context()
@ -203,16 +204,17 @@ class FederationBase(object):
class PduToCheckSig(namedtuple("PduToCheckSig", [
"pdu", "redacted_pdu_json", "event_id_domain", "sender_domain", "deferreds",
"pdu", "redacted_pdu_json", "sender_domain", "deferreds",
])):
pass
def _check_sigs_on_pdus(keyring, pdus):
def _check_sigs_on_pdus(keyring, room_version, pdus):
"""Check that the given events are correctly signed
Args:
keyring (synapse.crypto.Keyring): keyring object to do the checks
room_version (str): the room version of the PDUs
pdus (Collection[EventBase]): the events to be checked
Returns:
@ -225,9 +227,7 @@ def _check_sigs_on_pdus(keyring, pdus):
# we want to check that the event is signed by:
#
# (a) the server which created the event_id
#
# (b) the sender's server.
# (a) the sender's server
#
# - except in the case of invites created from a 3pid invite, which are exempt
# from this check, because the sender has to match that of the original 3pid
@ -241,34 +241,26 @@ def _check_sigs_on_pdus(keyring, pdus):
# and signatures are *supposed* to be valid whether or not an event has been
# redacted. But this isn't the worst of the ways that 3pid invites are broken.
#
# (b) for V1 and V2 rooms, the server which created the event_id
#
# let's start by getting the domain for each pdu, and flattening the event back
# to JSON.
pdus_to_check = [
PduToCheckSig(
pdu=p,
redacted_pdu_json=prune_event(p).get_pdu_json(),
event_id_domain=get_domain_from_id(p.event_id),
sender_domain=get_domain_from_id(p.sender),
deferreds=[],
)
for p in pdus
]
# first make sure that the event is signed by the event_id's domain
deferreds = keyring.verify_json_objects_for_server([
(p.event_id_domain, p.redacted_pdu_json)
for p in pdus_to_check
])
for p, d in zip(pdus_to_check, deferreds):
p.deferreds.append(d)
# now let's look for events where the sender's domain is different to the
# event id's domain (normally only the case for joins/leaves), and add additional
# checks.
# First we check that the sender event is signed by the sender's domain
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [
p for p in pdus_to_check
if p.sender_domain != p.event_id_domain and not _is_invite_via_3pid(p.pdu)
if not _is_invite_via_3pid(p.pdu)
]
more_deferreds = keyring.verify_json_objects_for_server([
@ -279,19 +271,43 @@ def _check_sigs_on_pdus(keyring, pdus):
for p, d in zip(pdus_to_check_sender, more_deferreds):
p.deferreds.append(d)
# now let's look for events where the sender's domain is different to the
# event id's domain (normally only the case for joins/leaves), and add additional
# checks. Only do this if the room version has a concept of event ID domain
if room_version in (
RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST,
):
pdus_to_check_event_id = [
p for p in pdus_to_check
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
]
more_deferreds = keyring.verify_json_objects_for_server([
(get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json)
for p in pdus_to_check_event_id
])
for p, d in zip(pdus_to_check_event_id, more_deferreds):
p.deferreds.append(d)
elif room_version in (RoomVersions.V3,):
pass # No further checks needed, as event IDs are hashes here
else:
raise RuntimeError("Unrecognized room version %s" % (room_version,))
# replace lists of deferreds with single Deferreds
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
def _flatten_deferred_list(deferreds):
"""Given a list of one or more deferreds, either return the single deferred, or
combine into a DeferredList.
"""Given a list of deferreds, either return the single deferred,
combine into a DeferredList, or return an already resolved deferred.
"""
if len(deferreds) > 1:
return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
else:
assert len(deferreds) == 1
elif len(deferreds) == 1:
return deferreds[0]
else:
return defer.succeed(None)
def _is_invite_via_3pid(event):
@ -319,7 +335,7 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False):
"""
# we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc)
assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth'))
assert_params_in_dict(pdu_json, ('type', 'depth'))
depth = pdu_json['depth']
if not isinstance(depth, six.integer_types):

View file

@ -37,8 +37,7 @@ from synapse.api.errors import (
HttpResponseException,
SynapseError,
)
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import room_version_to_event_format
from synapse.events import builder, room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
@ -72,7 +71,8 @@ class FederationClient(FederationBase):
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
self.event_builder_factory = hs.get_event_builder_factory()
self.hostname = hs.hostname
self.signing_key = hs.config.signing_key[0]
self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache",
@ -205,7 +205,7 @@ class FederationClient(FederationBase):
# FIXME: We should handle signature failures more gracefully.
pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
self._check_sigs_and_hashes(pdus),
self._check_sigs_and_hashes(room_version, pdus),
consumeErrors=True,
).addErrback(unwrapFirstError))
@ -268,7 +268,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
signed_pdu = yield self._check_sigs_and_hash(pdu)
signed_pdu = yield self._check_sigs_and_hash(room_version, pdu)
break
@ -608,18 +608,10 @@ class FederationClient(FederationBase):
if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = []
# Strip off the fields that we want to clobber.
pdu_dict.pop("origin", None)
pdu_dict.pop("origin_server_ts", None)
pdu_dict.pop("unsigned", None)
builder = self.event_builder_factory.new(room_version, pdu_dict)
add_hashes_and_signatures(
builder,
self.hs.hostname,
self.hs.config.signing_key[0]
ev = builder.create_local_event_from_event_dict(
self._clock, self.hostname, self.signing_key,
format_version=event_format, event_dict=pdu_dict,
)
ev = builder.build()
defer.returnValue(
(destination, ev, event_format)
@ -751,18 +743,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
time_now = self._clock.time_msec()
try:
code, content = yield self.transport_layer.send_invite(
destination=destination,
room_id=room_id,
event_id=event_id,
content=pdu.get_pdu_json(time_now),
)
except HttpResponseException as e:
if e.code == 403:
raise e.to_synapse_error()
raise
room_version = yield self.store.get_room_version(room_id)
content = yield self._do_send_invite(destination, pdu, room_version)
pdu_dict = content["event"]
@ -774,12 +757,61 @@ class FederationClient(FederationBase):
pdu = event_from_pdu_json(pdu_dict, format_ver)
# Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu)
pdu = yield self._check_sigs_and_hash(room_version, pdu)
# FIXME: We should handle signature failures more gracefully.
defer.returnValue(pdu)
@defer.inlineCallbacks
def _do_send_invite(self, destination, pdu, room_version):
"""Actually sends the invite, first trying v2 API and falling back to
v1 API if necessary.
Args:
destination (str): Target server
pdu (FrozenEvent)
room_version (str)
Returns:
dict: The event as a dict as returned by the remote server
"""
time_now = self._clock.time_msec()
try:
content = yield self.transport_layer.send_invite_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content={
"event": pdu.get_pdu_json(time_now),
"room_version": room_version,
"invite_room_state": pdu.unsigned.get("invite_room_state", []),
},
)
defer.returnValue(content)
except HttpResponseException as e:
if e.code in [400, 404]:
if room_version in (RoomVersions.V1, RoomVersions.V2):
pass # We'll fall through
else:
raise Exception("Remote server is too old")
elif e.code == 403:
raise e.to_synapse_error()
else:
raise
# Didn't work, try v1 API.
# Note the v1 API returns a tuple of `(200, content)`
_, content = yield self.transport_layer.send_invite_v1(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
defer.returnValue(content)
def send_leave(self, destinations, pdu):
"""Sends a leave event to one of a list of homeservers.

View file

@ -25,7 +25,7 @@ from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
from synapse.api.constants import EventTypes
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
FederationError,
@ -322,7 +322,7 @@ class FederationServer(FederationBase):
if self.hs.is_mine_id(event.event_id):
event.signatures.update(
compute_event_signature(
event,
event.get_pdu_json(),
self.hs.hostname,
self.hs.config.signing_key[0]
)
@ -620,16 +620,19 @@ class FederationServer(FederationBase):
"""
# check that it's actually being sent from a valid destination to
# workaround bug #1753 in 0.18.5 and 0.18.6
if origin != get_domain_from_id(pdu.event_id):
if origin != get_domain_from_id(pdu.sender):
# We continue to accept join events from any server; this is
# necessary for the federation join dance to work correctly.
# (When we join over federation, the "helper" server is
# responsible for sending out the join event, rather than the
# origin. See bug #1893).
# origin. See bug #1893. This is also true for some third party
# invites).
if not (
pdu.type == 'm.room.member' and
pdu.content and
pdu.content.get("membership", None) == 'join'
pdu.content.get("membership", None) in (
Membership.JOIN, Membership.INVITE,
)
):
logger.info(
"Discarding PDU %s from invalid origin %s",
@ -642,9 +645,12 @@ class FederationServer(FederationBase):
pdu.event_id, origin
)
# We've already checked that we know the room version by this point
room_version = yield self.store.get_room_version(pdu.room_id)
# Check signature.
try:
pdu = yield self._check_sigs_and_hash(pdu)
pdu = yield self._check_sigs_and_hash(room_version, pdu)
except SynapseError as e:
raise FederationError(
"ERROR",

View file

@ -175,7 +175,7 @@ class TransactionQueue(object):
def handle_event(event):
# Only send events for this server.
send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
is_mine = self.is_mine_id(event.event_id)
is_mine = self.is_mine_id(event.sender)
if not is_mine and send_on_behalf_of is None:
return

View file

@ -21,7 +21,7 @@ from six.moves import urllib
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.urls import FEDERATION_V1_PREFIX
from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
from synapse.util.logutils import log_function
logger = logging.getLogger(__name__)
@ -289,7 +289,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_invite(self, destination, room_id, event_id, content):
def send_invite_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
@ -301,6 +301,20 @@ class TransportLayerClient(object):
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def send_invite_v2(self, destination, room_id, event_id, content):
path = _create_v2_path("/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
)
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def get_public_rooms(self, remote_server, limit, since_token,
@ -958,3 +972,24 @@ def _create_v1_path(path, *args):
FEDERATION_V1_PREFIX
+ path % tuple(urllib.parse.quote(arg, "") for arg in args)
)
def _create_v2_path(path, *args):
"""Creates a path against V2 federation API from the path template and
args. Ensures that all args are url encoded.
Example:
_create_v2_path("/event/%s/", event_id)
Args:
path (str): String template for the path
args: ([str]): Args to insert into path. Each arg will be url encoded
Returns:
str
"""
return (
FEDERATION_V2_PREFIX
+ path % tuple(urllib.parse.quote(arg, "") for arg in args)
)

View file

@ -57,8 +57,8 @@ class DirectoryHandler(BaseHandler):
# general association creation for both human users and app services
for wchar in string.whitespace:
if wchar in room_alias.localpart:
raise SynapseError(400, "Invalid characters in room alias")
if wchar in room_alias.localpart:
raise SynapseError(400, "Invalid characters in room alias")
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")

View file

@ -102,7 +102,7 @@ class FederationHandler(BaseHandler):
self.hs = hs
self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.store = hs.get_datastore()
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
@ -1300,7 +1300,7 @@ class FederationHandler(BaseHandler):
event.signatures.update(
compute_event_signature(
event,
event.get_pdu_json(),
self.hs.hostname,
self.hs.config.signing_key[0]
)
@ -2282,7 +2282,7 @@ class FederationHandler(BaseHandler):
room_version = yield self.store.get_room_version(room_id)
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_new(builder)
EventValidator().validate_builder(builder)
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder
)
@ -2291,6 +2291,12 @@ class FederationHandler(BaseHandler):
room_version, event_dict, event, context
)
EventValidator().validate_new(event)
# We need to tell the transaction queue to send this out, even
# though the sender isn't a local user.
event.internal_metadata.send_on_behalf_of = self.hs.hostname
try:
yield self.auth.check_from_context(room_version, event, context)
except AuthError as e:
@ -2340,6 +2346,10 @@ class FederationHandler(BaseHandler):
raise e
yield self._check_signature(event, context)
# We need to tell the transaction queue to send this out, even
# though the sender isn't a local user.
event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
# XXX we send the invite here, but send_membership_event also sends it,
# so we end up making two requests. I think this is redundant.
returned_invite = yield self.send_invite(origin, event)
@ -2376,10 +2386,11 @@ class FederationHandler(BaseHandler):
# auth check code will explode appropriately.
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_new(builder)
EventValidator().validate_builder(builder)
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
EventValidator().validate_new(event)
defer.returnValue((event, context))
@defer.inlineCallbacks

View file

@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
from twisted.internet.defer import succeed
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions
from synapse.api.constants import EventTypes, Membership, RoomVersions
from synapse.api.errors import (
AuthError,
Codes,
@ -31,7 +31,6 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.urls import ConsentURIBuilder
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
@ -288,7 +287,7 @@ class EventCreationHandler(object):
builder = self.event_builder_factory.new(room_version, event_dict)
self.validator.validate_new(builder)
self.validator.validate_builder(builder)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
@ -326,6 +325,8 @@ class EventCreationHandler(object):
prev_events_and_hashes=prev_events_and_hashes,
)
self.validator.validate_new(event)
defer.returnValue((event, context))
def _is_exempt_from_privacy_policy(self, builder, requester):
@ -543,40 +544,19 @@ class EventCreationHandler(object):
prev_events_and_hashes = \
yield self.store.get_prev_events_for_room(builder.room_id)
if prev_events_and_hashes:
depth = max([d for _, _, d in prev_events_and_hashes]) + 1
# we cap depth of generated events, to ensure that they are not
# rejected by other servers (and so that they can be persisted in
# the db)
depth = min(depth, MAX_DEPTH)
else:
depth = 1
prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in prev_events_and_hashes
]
builder.prev_events = prev_events
builder.depth = depth
context = yield self.state.compute_event_context(builder)
event = yield builder.build(
prev_event_ids=[p for p, _ in prev_events],
)
context = yield self.state.compute_event_context(event)
if requester:
context.app_service = requester.app_service
if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
)
yield self.auth.add_auth_events(builder, context)
signing_key = self.hs.config.signing_key[0]
add_hashes_and_signatures(
builder, self.server_name, signing_key
)
event = builder.build()
self.validator.validate_new(event)
logger.debug(
"Created event %s",
@ -765,7 +745,8 @@ class EventCreationHandler(object):
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
if self.auth.check_redaction(event, auth_events=auth_events):
room_version = yield self.store.get_room_version(event.room_id)
if self.auth.check_redaction(room_version, event, auth_events=auth_events):
original_event = yield self.store.get_event(
event.redacts,
check_redacted=False,
@ -779,6 +760,9 @@ class EventCreationHandler(object):
"You don't have permission to redact events"
)
# We've already checked.
event.internal_metadata.recheck_redaction = False
if event.type == EventTypes.Create:
prev_state_ids = yield context.get_prev_state_ids(self.store)
if prev_state_ids:

View file

@ -960,7 +960,8 @@ class RoomMemberHandler(object):
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))
if len(current_state_ids) == 1 and create_event_id:
defer.returnValue(self.hs.is_mine_id(create_event_id))
# We can only get here if we're in the process of creating the room
defer.returnValue(True)
for etype, state_key in current_state_ids:
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):

View file

@ -12,7 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import random
import time
import attr
from netaddr import IPAddress
@ -20,14 +23,30 @@ from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, readBody
from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
from synapse.util.caches.ttlcache import TTLCache
from synapse.util.logcontext import make_deferred_yieldable
# period to cache .well-known results for by default
WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600
# jitter to add to the .well-known default cache ttl
WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60
# period to cache failure to fetch .well-known for
WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
# cap for .well-known cache period
WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
logger = logging.getLogger(__name__)
well_known_cache = TTLCache('well-known')
@implementer(IAgent)
@ -43,13 +62,20 @@ class MatrixFederationAgent(object):
tls_client_options_factory (ClientTLSOptionsFactory|None):
factory to use for fetching client tls options, or none to disable TLS.
_well_known_tls_policy (IPolicyForHTTPS|None):
TLS policy to use for fetching .well-known files. None to use a default
(browser-like) implementation.
srv_resolver (SrvResolver|None):
SRVResolver impl to use for looking up SRV records. None to use a default
implementation.
"""
def __init__(
self, reactor, tls_client_options_factory, _srv_resolver=None,
self, reactor, tls_client_options_factory,
_well_known_tls_policy=None,
_srv_resolver=None,
_well_known_cache=well_known_cache,
):
self._reactor = reactor
self._tls_client_options_factory = tls_client_options_factory
@ -62,6 +88,18 @@ class MatrixFederationAgent(object):
self._pool.maxPersistentPerHost = 5
self._pool.cachedConnectionTimeout = 2 * 60
agent_args = {}
if _well_known_tls_policy is not None:
# the param is called 'contextFactory', but actually passing a
# contextfactory is deprecated, and it expects an IPolicyForHTTPS.
agent_args['contextFactory'] = _well_known_tls_policy
_well_known_agent = RedirectAgent(
Agent(self._reactor, pool=self._pool, **agent_args),
)
self._well_known_agent = _well_known_agent
self._well_known_cache = _well_known_cache
@defer.inlineCallbacks
def request(self, method, uri, headers=None, bodyProducer=None):
"""
@ -114,7 +152,11 @@ class MatrixFederationAgent(object):
class EndpointFactory(object):
@staticmethod
def endpointForURI(_uri):
logger.info("Connecting to %s:%s", res.target_host, res.target_port)
logger.info(
"Connecting to %s:%i",
res.target_host.decode("ascii"),
res.target_port,
)
ep = HostnameEndpoint(self._reactor, res.target_host, res.target_port)
if tls_options is not None:
ep = wrapClientTLS(tls_options, ep)
@ -127,7 +169,7 @@ class MatrixFederationAgent(object):
defer.returnValue(res)
@defer.inlineCallbacks
def _route_matrix_uri(self, parsed_uri):
def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
"""Helper for `request`: determine the routing for a Matrix URI
Args:
@ -135,6 +177,9 @@ class MatrixFederationAgent(object):
parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
if there is no explicit port given.
lookup_well_known (bool): True if we should look up the .well-known file if
there is no SRV record.
Returns:
Deferred[_RoutingResult]
"""
@ -169,6 +214,42 @@ class MatrixFederationAgent(object):
service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
server_list = yield self._srv_resolver.resolve_service(service_name)
if not server_list and lookup_well_known:
# try a .well-known lookup
well_known_server = yield self._get_well_known(parsed_uri.host)
if well_known_server:
# if we found a .well-known, start again, but don't do another
# .well-known lookup.
# parse the server name in the .well-known response into host/port.
# (This code is lifted from twisted.web.client.URI.fromBytes).
if b':' in well_known_server:
well_known_host, well_known_port = well_known_server.rsplit(b':', 1)
try:
well_known_port = int(well_known_port)
except ValueError:
# the part after the colon could not be parsed as an int
# - we assume it is an IPv6 literal with no port (the closing
# ']' stops it being parsed as an int)
well_known_host, well_known_port = well_known_server, -1
else:
well_known_host, well_known_port = well_known_server, -1
new_uri = URI(
scheme=parsed_uri.scheme,
netloc=well_known_server,
host=well_known_host,
port=well_known_port,
path=parsed_uri.path,
params=parsed_uri.params,
query=parsed_uri.query,
fragment=parsed_uri.fragment,
)
res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
defer.returnValue(res)
if not server_list:
target_host = parsed_uri.host
port = 8448
@ -190,6 +271,114 @@ class MatrixFederationAgent(object):
target_port=port,
))
@defer.inlineCallbacks
def _get_well_known(self, server_name):
"""Attempt to fetch and parse a .well-known file for the given server
Args:
server_name (bytes): name of the server, from the requested url
Returns:
Deferred[bytes|None]: either the new server name, from the .well-known, or
None if there was no .well-known file.
"""
try:
cached = self._well_known_cache[server_name]
defer.returnValue(cached)
except KeyError:
pass
# TODO: should we linearise so that we don't end up doing two .well-known requests
# for the same server in parallel?
uri = b"https://%s/.well-known/matrix/server" % (server_name, )
uri_str = uri.decode("ascii")
logger.info("Fetching %s", uri_str)
try:
response = yield make_deferred_yieldable(
self._well_known_agent.request(b"GET", uri),
)
body = yield make_deferred_yieldable(readBody(response))
if response.code != 200:
raise Exception("Non-200 response %s" % (response.code, ))
except Exception as e:
logger.info("Error fetching %s: %s", uri_str, e)
# add some randomness to the TTL to avoid a stampeding herd every hour
# after startup
cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
self._well_known_cache.set(server_name, None, cache_period)
defer.returnValue(None)
try:
parsed_body = json.loads(body.decode('utf-8'))
logger.info("Response from .well-known: %s", parsed_body)
if not isinstance(parsed_body, dict):
raise Exception("not a dict")
if "m.server" not in parsed_body:
raise Exception("Missing key 'm.server'")
except Exception as e:
raise Exception("invalid .well-known response from %s: %s" % (uri_str, e,))
result = parsed_body["m.server"].encode("ascii")
cache_period = _cache_period_from_headers(
response.headers,
time_now=self._reactor.seconds,
)
if cache_period is None:
cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
# add some randomness to the TTL to avoid a stampeding herd every 24 hours
# after startup
cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
else:
cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)
if cache_period > 0:
self._well_known_cache.set(server_name, result, cache_period)
defer.returnValue(result)
def _cache_period_from_headers(headers, time_now=time.time):
cache_controls = _parse_cache_control(headers)
if b'no-store' in cache_controls:
return 0
if b'max-age' in cache_controls:
try:
max_age = int(cache_controls[b'max-age'])
return max_age
except ValueError:
pass
expires = headers.getRawHeaders(b'expires')
if expires is not None:
try:
expires_date = stringToDatetime(expires[-1])
return expires_date - time_now()
except ValueError:
# RFC7234 says 'A cache recipient MUST interpret invalid date formats,
# especially the value "0", as representing a time in the past (i.e.,
# "already expired").
return 0
return None
def _parse_cache_control(headers):
cache_controls = {}
for hdr in headers.getRawHeaders(b'cache-control', []):
for directive in hdr.split(b','):
splits = [x.strip() for x in directive.split(b'=', 1)]
k = splits[0].lower()
v = splits[1] if len(splits) > 1 else None
cache_controls[k] = v
return cache_controls
@attr.s
class _RoutingResult(object):

View file

@ -84,7 +84,7 @@ def _rule_to_template(rule):
templaterule["pattern"] = thecond["pattern"]
if unscoped_rule_id:
templaterule['rule_id'] = unscoped_rule_id
templaterule['rule_id'] = unscoped_rule_id
if 'default' in rule:
templaterule['default'] = rule['default']
return templaterule

View file

@ -95,7 +95,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event_and_contexts = []
for event_payload in event_payloads:
event_dict = event_payload["event"]
format_ver = content["event_format_version"]
format_ver = event_payload["event_format_version"]
internal_metadata = event_payload["internal_metadata"]
rejected_reason = event_payload["rejected_reason"]

View file

@ -101,16 +101,7 @@ class ConsentResource(Resource):
"missing in config file.",
)
# daemonize changes the cwd to /, so make the path absolute now.
consent_template_directory = path.abspath(
hs.config.user_consent_template_dir,
)
if not path.isdir(consent_template_directory):
raise ConfigError(
"Could not find template directory '%s'" % (
consent_template_directory,
),
)
consent_template_directory = hs.config.user_consent_template_dir
loader = jinja2.FileSystemLoader(consent_template_directory)
self._jinja_env = jinja2.Environment(

View file

@ -355,10 +355,7 @@ class HomeServer(object):
return Keyring(self)
def build_event_builder_factory(self):
return EventBuilderFactory(
clock=self.get_clock(),
hostname=self.hostname,
)
return EventBuilderFactory(self)
def build_filtering(self):
return Filtering(self)

View file

@ -608,7 +608,7 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
state_sets, event_map, state_res_store.get_events,
)
elif room_version in (
RoomVersions.VDH_TEST, RoomVersions.STATE_V2_TEST, RoomVersions.V2,
RoomVersions.STATE_V2_TEST, RoomVersions.V2, RoomVersions.V3,
):
return v2.resolve_events_with_store(
room_version, state_sets, event_map, state_res_store,

View file

@ -317,7 +317,7 @@ class DataStore(RoomMemberStore, RoomStore,
thirty_days_ago_in_secs))
for row in txn:
if row[0] is 'unknown':
if row[0] == 'unknown':
pass
results[row[0]] = row[1]

View file

@ -125,6 +125,29 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
return dict(txn)
@defer.inlineCallbacks
def get_max_depth_of(self, event_ids):
"""Returns the max depth of a set of event IDs
Args:
event_ids (list[str])
Returns
Deferred[int]
"""
rows = yield self._simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
retcols=("depth",),
desc="get_max_depth_of",
)
if not rows:
defer.returnValue(0)
else:
defer.returnValue(max(row["depth"] for row in rows))
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn(
txn,

View file

@ -904,105 +904,105 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
for room_id, current_state_tuple in iteritems(state_delta_by_room):
to_delete, to_insert = current_state_tuple
to_delete, to_insert = current_state_tuple
# First we add entries to the current_state_delta_stream. We
# do this before updating the current_state_events table so
# that we can use it to calculate the `prev_event_id`. (This
# allows us to not have to pull out the existing state
# unnecessarily).
sql = """
INSERT INTO current_state_delta_stream
(stream_id, room_id, type, state_key, event_id, prev_event_id)
SELECT ?, ?, ?, ?, ?, (
SELECT event_id FROM current_state_events
WHERE room_id = ? AND type = ? AND state_key = ?
)
"""
txn.executemany(sql, (
(
max_stream_order, room_id, etype, state_key, None,
room_id, etype, state_key,
)
for etype, state_key in to_delete
# We sanity check that we're deleting rather than updating
if (etype, state_key) not in to_insert
))
txn.executemany(sql, (
(
max_stream_order, room_id, etype, state_key, ev_id,
room_id, etype, state_key,
)
for (etype, state_key), ev_id in iteritems(to_insert)
))
# Now we actually update the current_state_events table
txn.executemany(
"DELETE FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?",
(
(room_id, etype, state_key)
for etype, state_key in itertools.chain(to_delete, to_insert)
),
# First we add entries to the current_state_delta_stream. We
# do this before updating the current_state_events table so
# that we can use it to calculate the `prev_event_id`. (This
# allows us to not have to pull out the existing state
# unnecessarily).
sql = """
INSERT INTO current_state_delta_stream
(stream_id, room_id, type, state_key, event_id, prev_event_id)
SELECT ?, ?, ?, ?, ?, (
SELECT event_id FROM current_state_events
WHERE room_id = ? AND type = ? AND state_key = ?
)
self._simple_insert_many_txn(
txn,
table="current_state_events",
values=[
{
"event_id": ev_id,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
}
for key, ev_id in iteritems(to_insert)
],
"""
txn.executemany(sql, (
(
max_stream_order, room_id, etype, state_key, None,
room_id, etype, state_key,
)
txn.call_after(
self._curr_state_delta_stream_cache.entity_has_changed,
room_id, max_stream_order,
for etype, state_key in to_delete
# We sanity check that we're deleting rather than updating
if (etype, state_key) not in to_insert
))
txn.executemany(sql, (
(
max_stream_order, room_id, etype, state_key, ev_id,
room_id, etype, state_key,
)
for (etype, state_key), ev_id in iteritems(to_insert)
))
# Invalidate the various caches
# Now we actually update the current_state_events table
# Figure out the changes of membership to invalidate the
# `get_rooms_for_user` cache.
# We find out which membership events we may have deleted
# and which we have added, then we invlidate the caches for all
# those users.
members_changed = set(
state_key
for ev_type, state_key in itertools.chain(to_delete, to_insert)
if ev_type == EventTypes.Member
)
txn.executemany(
"DELETE FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?",
(
(room_id, etype, state_key)
for etype, state_key in itertools.chain(to_delete, to_insert)
),
)
for member in members_changed:
self._invalidate_cache_and_stream(
txn, self.get_rooms_for_user_with_stream_ordering, (member,)
)
self._simple_insert_many_txn(
txn,
table="current_state_events",
values=[
{
"event_id": ev_id,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
}
for key, ev_id in iteritems(to_insert)
],
)
for host in set(get_domain_from_id(u) for u in members_changed):
self._invalidate_cache_and_stream(
txn, self.is_host_joined, (room_id, host)
)
self._invalidate_cache_and_stream(
txn, self.was_host_joined, (room_id, host)
)
txn.call_after(
self._curr_state_delta_stream_cache.entity_has_changed,
room_id, max_stream_order,
)
# Invalidate the various caches
# Figure out the changes of membership to invalidate the
# `get_rooms_for_user` cache.
# We find out which membership events we may have deleted
# and which we have added, then we invlidate the caches for all
# those users.
members_changed = set(
state_key
for ev_type, state_key in itertools.chain(to_delete, to_insert)
if ev_type == EventTypes.Member
)
for member in members_changed:
self._invalidate_cache_and_stream(
txn, self.get_users_in_room, (room_id,)
txn, self.get_rooms_for_user_with_stream_ordering, (member,)
)
for host in set(get_domain_from_id(u) for u in members_changed):
self._invalidate_cache_and_stream(
txn, self.get_room_summary, (room_id,)
txn, self.is_host_joined, (room_id, host)
)
self._invalidate_cache_and_stream(
txn, self.was_host_joined, (room_id, host)
)
self._invalidate_cache_and_stream(
txn, self.get_current_state_ids, (room_id,)
)
self._invalidate_cache_and_stream(
txn, self.get_users_in_room, (room_id,)
)
self._invalidate_cache_and_stream(
txn, self.get_room_summary, (room_id,)
)
self._invalidate_cache_and_stream(
txn, self.get_current_state_ids, (room_id,)
)
def _update_forward_extremities_txn(self, txn, new_forward_extremities,
max_stream_order):

View file

@ -21,13 +21,14 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import EventFormatVersions
from synapse.api.constants import EventFormatVersions, EventTypes
from synapse.api.errors import NotFoundError
from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
# these are only included to make the type annotations work
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
from synapse.util.logcontext import (
LoggingContext,
PreserveLoggingContext,
@ -162,7 +163,6 @@ class EventsWorkerStore(SQLBaseStore):
missing_events = yield self._enqueue_events(
missing_events_ids,
check_redacted=check_redacted,
allow_rejected=allow_rejected,
)
@ -174,6 +174,29 @@ class EventsWorkerStore(SQLBaseStore):
if not entry:
continue
# Starting in room version v3, some redactions need to be rechecked if we
# didn't have the redacted event at the time, so we recheck on read
# instead.
if not allow_rejected and entry.event.type == EventTypes.Redaction:
if entry.event.internal_metadata.need_to_check_redaction():
orig = yield self.get_event(
entry.event.redacts,
allow_none=True,
allow_rejected=True,
get_prev_content=False,
)
expected_domain = get_domain_from_id(entry.event.sender)
if orig and get_domain_from_id(orig.sender) == expected_domain:
# This redaction event is allowed. Mark as not needing a
# recheck.
entry.event.internal_metadata.recheck_redaction = False
else:
# We don't have the event that is being redacted, so we
# assume that the event isn't authorized for now. (If we
# later receive the event, then we will always redact
# it anyway, since we have this redaction)
continue
if allow_rejected or not entry.event.rejected_reason:
if check_redacted and entry.redacted_event:
event = entry.redacted_event
@ -197,7 +220,7 @@ class EventsWorkerStore(SQLBaseStore):
defer.returnValue(events)
def _invalidate_get_event_cache(self, event_id):
self._get_event_cache.invalidate((event_id,))
self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
"""Fetch events from the caches
@ -310,7 +333,7 @@ class EventsWorkerStore(SQLBaseStore):
self.hs.get_reactor().callFromThread(fire, event_list, e)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
def _enqueue_events(self, events, allow_rejected=False):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@ -443,6 +466,19 @@ class EventsWorkerStore(SQLBaseStore):
# will serialise this field correctly
redacted_event.unsigned["redacted_because"] = because
# Starting in room version v3, some redactions need to be
# rechecked if we didn't have the redacted event at the
# time, so we recheck on read instead.
if because.internal_metadata.need_to_check_redaction():
expected_domain = get_domain_from_id(original_ev.sender)
if get_domain_from_id(because.sender) == expected_domain:
# This redaction event is allowed. Mark as not needing a
# recheck.
because.internal_metadata.recheck_redaction = False
else:
# Senders don't match, so the event isn't actually redacted
redacted_event = None
cache_entry = _EventCacheEntry(
event=original_ev,
redacted_event=redacted_event,

View file

@ -201,7 +201,7 @@ class Linearizer(object):
if entry[0] >= self.max_count:
res = self._await_lock(key)
else:
logger.info(
logger.debug(
"Acquired uncontended linearizer lock %r for key %r", self.name, key,
)
entry[0] += 1
@ -215,7 +215,7 @@ class Linearizer(object):
try:
yield
finally:
logger.info("Releasing linearizer lock %r for key %r", self.name, key)
logger.debug("Releasing linearizer lock %r for key %r", self.name, key)
# We've finished executing so check if there are any things
# blocked waiting to execute and start one of them
@ -247,7 +247,7 @@ class Linearizer(object):
"""
entry = self.key_to_defer[key]
logger.info(
logger.debug(
"Waiting to acquire linearizer lock %r for key %r", self.name, key,
)
@ -255,7 +255,7 @@ class Linearizer(object):
entry[1][new_defer] = 1
def cb(_r):
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
entry[0] += 1
# if the code holding the lock completes synchronously, then it
@ -273,7 +273,7 @@ class Linearizer(object):
def eb(e):
logger.info("defer %r got err %r", new_defer, e)
if isinstance(e, CancelledError):
logger.info(
logger.debug(
"Cancelling wait for linearizer lock %r for key %r",
self.name, key,
)

View file

@ -0,0 +1,161 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
import attr
from sortedcontainers import SortedList
from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
SENTINEL = object()
class TTLCache(object):
"""A key/value cache implementation where each entry has its own TTL"""
def __init__(self, cache_name, timer=time.time):
# map from key to _CacheEntry
self._data = {}
# the _CacheEntries, sorted by expiry time
self._expiry_list = SortedList()
self._timer = timer
self._metrics = register_cache("ttl", cache_name, self)
def set(self, key, value, ttl):
"""Add/update an entry in the cache
Args:
key: key for this entry
value: value for this entry
ttl (float): TTL for this entry, in seconds
"""
expiry = self._timer() + ttl
self.expire()
e = self._data.pop(key, SENTINEL)
if e != SENTINEL:
self._expiry_list.remove(e)
entry = _CacheEntry(expiry_time=expiry, key=key, value=value)
self._data[key] = entry
self._expiry_list.add(entry)
def get(self, key, default=SENTINEL):
"""Get a value from the cache
Args:
key: key to look up
default: default value to return, if key is not found. If not set, and the
key is not found, a KeyError will be raised
Returns:
value from the cache, or the default
"""
self.expire()
e = self._data.get(key, SENTINEL)
if e == SENTINEL:
self._metrics.inc_misses()
if default == SENTINEL:
raise KeyError(key)
return default
self._metrics.inc_hits()
return e.value
def get_with_expiry(self, key):
"""Get a value, and its expiry time, from the cache
Args:
key: key to look up
Returns:
Tuple[Any, float]: the value from the cache, and the expiry time
Raises:
KeyError if the entry is not found
"""
self.expire()
try:
e = self._data[key]
except KeyError:
self._metrics.inc_misses()
raise
self._metrics.inc_hits()
return e.value, e.expiry_time
def pop(self, key, default=SENTINEL):
"""Remove a value from the cache
If key is in the cache, remove it and return its value, else return default.
If default is not given and key is not in the cache, a KeyError is raised.
Args:
key: key to look up
default: default value to return, if key is not found. If not set, and the
key is not found, a KeyError will be raised
Returns:
value from the cache, or the default
"""
self.expire()
e = self._data.pop(key, SENTINEL)
if e == SENTINEL:
self._metrics.inc_misses()
if default == SENTINEL:
raise KeyError(key)
return default
self._expiry_list.remove(e)
self._metrics.inc_hits()
return e.value
def __getitem__(self, key):
return self.get(key)
def __delitem__(self, key):
self.pop(key)
def __contains__(self, key):
return key in self._data
def __len__(self):
self.expire()
return len(self._data)
def expire(self):
"""Run the expiry on the cache. Any entries whose expiry times are due will
be removed
"""
now = self._timer()
while self._expiry_list:
first_entry = self._expiry_list[0]
if first_entry.expiry_time - now > 0.0:
break
del self._data[first_entry.key]
del self._expiry_list[0]
@attr.s(frozen=True, slots=True)
class _CacheEntry(object):
"""TTLCache entry"""
# expiry_time is the first attribute, so that entries are sorted by expiry.
expiry_time = attr.ib()
key = attr.ib()
value = attr.ib()

View file

@ -285,7 +285,10 @@ class LoggingContext(object):
self.alive = False
# if we have a parent, pass our CPU usage stats on
if self.parent_context is not None:
if (
self.parent_context is not None
and hasattr(self.parent_context, '_resource_usage')
):
self.parent_context._resource_usage += self._resource_usage
# reset them in case we get entered again

View file

@ -50,8 +50,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
"homeserver.yaml",
"lemurs.win.log.config",
"lemurs.win.signing.key",
"lemurs.win.tls.crt",
"lemurs.win.tls.key",
]
),
set(os.listdir(self.dir)),

75
tests/config/test_tls.py Normal file
View file

@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from synapse.config.tls import TlsConfig
from tests.unittest import TestCase
class TLSConfigTests(TestCase):
def test_warn_self_signed(self):
"""
Synapse will give a warning when it loads a self-signed certificate.
"""
config_dir = self.mktemp()
os.mkdir(config_dir)
with open(os.path.join(config_dir, "cert.pem"), 'w') as f:
f.write("""-----BEGIN CERTIFICATE-----
MIID6DCCAtACAws9CjANBgkqhkiG9w0BAQUFADCBtzELMAkGA1UEBhMCVFIxDzAN
BgNVBAgMBsOHb3J1bTEUMBIGA1UEBwwLQmHFn21ha8OnxLExEjAQBgNVBAMMCWxv
Y2FsaG9zdDEcMBoGA1UECgwTVHdpc3RlZCBNYXRyaXggTGFiczEkMCIGA1UECwwb
QXV0b21hdGVkIFRlc3RpbmcgQXV0aG9yaXR5MSkwJwYJKoZIhvcNAQkBFhpzZWN1
cml0eUB0d2lzdGVkbWF0cml4LmNvbTAgFw0xNzA3MTIxNDAxNTNaGA8yMTE3MDYx
ODE0MDE1M1owgbcxCzAJBgNVBAYTAlRSMQ8wDQYDVQQIDAbDh29ydW0xFDASBgNV
BAcMC0JhxZ9tYWvDp8SxMRIwEAYDVQQDDAlsb2NhbGhvc3QxHDAaBgNVBAoME1R3
aXN0ZWQgTWF0cml4IExhYnMxJDAiBgNVBAsMG0F1dG9tYXRlZCBUZXN0aW5nIEF1
dGhvcml0eTEpMCcGCSqGSIb3DQEJARYac2VjdXJpdHlAdHdpc3RlZG1hdHJpeC5j
b20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDwT6kbqtMUI0sMkx4h
I+L780dA59KfksZCqJGmOsMD6hte9EguasfkZzvCF3dk3NhwCjFSOvKx6rCwiteo
WtYkVfo+rSuVNmt7bEsOUDtuTcaxTzIFB+yHOYwAaoz3zQkyVW0c4pzioiLCGCmf
FLdiDBQGGp74tb+7a0V6kC3vMLFoM3L6QWq5uYRB5+xLzlPJ734ltyvfZHL3Us6p
cUbK+3WTWvb4ER0W2RqArAj6Bc/ERQKIAPFEiZi9bIYTwvBH27OKHRz+KoY/G8zY
+l+WZoJqDhupRAQAuh7O7V/y6bSP+KNxJRie9QkZvw1PSaGSXtGJI3WWdO12/Ulg
epJpAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAJXEq5P9xwvP9aDkXIqzcD0L8sf8
ewlhlxTQdeqt2Nace0Yk18lIo2oj1t86Y8jNbpAnZJeI813Rr5M7FbHCXoRc/SZG
I8OtG1xGwcok53lyDuuUUDexnK4O5BkjKiVlNPg4HPim5Kuj2hRNFfNt/F2BVIlj
iZupikC5MT1LQaRwidkSNxCku1TfAyueiBwhLnFwTmIGNnhuDCutEVAD9kFmcJN2
SznugAcPk4doX2+rL+ila+ThqgPzIkwTUHtnmjI0TI6xsDUlXz5S3UyudrE2Qsfz
s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
-----END CERTIFICATE-----""")
config = {
"tls_certificate_path": os.path.join(config_dir, "cert.pem"),
"no_tls": True,
"tls_fingerprints": []
}
t = TlsConfig()
t.read_config(config)
t.read_certificate_from_disk()
warnings = self.flushWarnings()
self.assertEqual(len(warnings), 1)
self.assertEqual(
warnings[0]["message"],
(
"Self-signed TLS certificates will not be accepted by "
"Synapse 1.0. Please either provide a valid certificate, "
"or use Synapse's ACME support to provision one."
)
)

View file

@ -18,7 +18,7 @@ import nacl.signing
from unpaddedbase64 import decode_base64
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.builder import EventBuilder
from synapse.events import FrozenEvent
from tests import unittest
@ -40,20 +40,18 @@ class EventSigningTestCase(unittest.TestCase):
self.signing_key.version = KEY_VER
def test_sign_minimal(self):
builder = EventBuilder(
{
'event_id': "$0:domain",
'origin': "domain",
'origin_server_ts': 1000000,
'signatures': {},
'type': "X",
'unsigned': {'age_ts': 1000000},
}
)
event_dict = {
'event_id': "$0:domain",
'origin': "domain",
'origin_server_ts': 1000000,
'signatures': {},
'type': "X",
'unsigned': {'age_ts': 1000000},
}
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key)
event = builder.build()
event = FrozenEvent(event_dict)
self.assertTrue(hasattr(event, 'hashes'))
self.assertIn('sha256', event.hashes)
@ -71,23 +69,21 @@ class EventSigningTestCase(unittest.TestCase):
)
def test_sign_message(self):
builder = EventBuilder(
{
'content': {'body': "Here is the message content"},
'event_id': "$0:domain",
'origin': "domain",
'origin_server_ts': 1000000,
'type': "m.room.message",
'room_id': "!r:domain",
'sender': "@u:domain",
'signatures': {},
'unsigned': {'age_ts': 1000000},
}
)
event_dict = {
'content': {'body': "Here is the message content"},
'event_id': "$0:domain",
'origin': "domain",
'origin_server_ts': 1000000,
'type': "m.room.message",
'room_id': "!r:domain",
'sender': "@u:domain",
'signatures': {},
'unsigned': {'age_ts': 1000000},
}
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key)
event = builder.build()
event = FrozenEvent(event_dict)
self.assertTrue(hasattr(event, 'hashes'))
self.assertIn('sha256', event.hashes)

View file

@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
from OpenSSL import SSL
def get_test_cert_file():
"""get the path to the test cert"""
# the cert file itself is made with:
#
# openssl req -x509 -newkey rsa:4096 -keyout server.pem -out server.pem -days 36500 \
# -nodes -subj '/CN=testserv'
return os.path.join(
os.path.dirname(__file__),
'server.pem',
)
class ServerTLSContext(object):
"""A TLS Context which presents our test cert."""
def __init__(self):
self.filename = get_test_cert_file()
def getContext(self):
ctx = SSL.Context(SSL.TLSv1_METHOD)
ctx.use_certificate_file(self.filename)
ctx.use_privatekey_file(self.filename)
return ctx

View file

@ -17,18 +17,26 @@ import logging
from mock import Mock
import treq
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.test.ssl_helpers import ServerTLSContext
from twisted.web.http import HTTPChannel
from twisted.web.http_headers import Headers
from twisted.web.iweb import IPolicyForHTTPS
from synapse.crypto.context_factory import ClientTLSOptionsFactory
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.http.federation.matrix_federation_agent import (
MatrixFederationAgent,
_cache_period_from_headers,
)
from synapse.http.federation.srv_resolver import Server
from synapse.util.caches.ttlcache import TTLCache
from synapse.util.logcontext import LoggingContext
from tests.http import ServerTLSContext
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.unittest import TestCase
@ -41,10 +49,14 @@ class MatrixFederationAgentTests(TestCase):
self.mock_resolver = Mock()
self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
self.agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=ClientTLSOptionsFactory(None),
_well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
_srv_resolver=self.mock_resolver,
_well_known_cache=self.well_known_cache,
)
def _make_connection(self, client_factory, expected_sni):
@ -65,10 +77,14 @@ class MatrixFederationAgentTests(TestCase):
# Normally this would be done by the TCP socket code in Twisted, but we are
# stubbing that out here.
client_protocol = client_factory.buildProtocol(None)
client_protocol.makeConnection(FakeTransport(server_tls_protocol, self.reactor))
client_protocol.makeConnection(
FakeTransport(server_tls_protocol, self.reactor, client_protocol),
)
# tell the server tls protocol to send its stuff back to the client, too
server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
server_tls_protocol.makeConnection(
FakeTransport(client_protocol, self.reactor, server_tls_protocol),
)
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
@ -101,9 +117,55 @@ class MatrixFederationAgentTests(TestCase):
try:
fetch_res = yield fetch_d
defer.returnValue(fetch_res)
except Exception as e:
logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e)
raise
finally:
_check_logcontext(context)
def _handle_well_known_connection(
self, client_factory, expected_sni, target_server, response_headers={},
):
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the
request is for a .well-known, and send the response.
Args:
client_factory (IProtocolFactory): outgoing connection
expected_sni (bytes): SNI that we expect the outgoing connection to send
target_server (bytes): target server that we should redirect to in the
.well-known response.
Returns:
HTTPChannel: server impl
"""
# make the connection for .well-known
well_known_server = self._make_connection(
client_factory,
expected_sni=expected_sni,
)
# check the .well-known request and send a response
self.assertEqual(len(well_known_server.requests), 1)
request = well_known_server.requests[0]
self._send_well_known_response(request, target_server, headers=response_headers)
return well_known_server
def _send_well_known_response(self, request, target_server, headers={}):
"""Check that an incoming request looks like a valid .well-known request, and
send back the response.
"""
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/.well-known/matrix/server')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'testserv'],
)
# send back a response
for k, v in headers.items():
request.setHeader(k, v)
request.write(b'{ "m.server": "%s" }' % (target_server,))
request.finish()
self.reactor.pump((0.1, ))
def test_get(self):
"""
happy-path test of a GET request with an explicit port
@ -283,9 +345,9 @@ class MatrixFederationAgentTests(TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_get_hostname_no_srv(self):
def test_get_no_srv_no_well_known(self):
"""
Test the behaviour when the server name has no port, and no SRV record
Test the behaviour when the server name has no port, no SRV, and no well-known
"""
self.mock_resolver.resolve_service.side_effect = lambda _: []
@ -300,11 +362,24 @@ class MatrixFederationAgentTests(TestCase):
b"_matrix._tcp.testserv",
)
# Make sure treq is trying to connect
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 443)
# fonx the connection
client_factory.clientConnectionFailed(None, Exception("nope"))
# attemptdelay on the hostnameendpoint is 0.3, so takes that long before the
# .well-known request fails.
self.reactor.pump((0.4,))
# we should fall back to a direct connection
self.assertEqual(len(clients), 2)
(host, port, client_factory, _timeout, _bindAddress) = clients[1]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8448)
# make a test server, and wire up the client
@ -327,6 +402,171 @@ class MatrixFederationAgentTests(TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_get_well_known(self):
"""Test the behaviour when the server name has no port and no SRV record, but
the .well-known redirects elsewhere
"""
self.mock_resolver.resolve_service.side_effect = lambda _: []
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once_with(
b"_matrix._tcp.testserv",
)
self.mock_resolver.resolve_service.reset_mock()
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 443)
self._handle_well_known_connection(
client_factory, expected_sni=b"testserv", target_server=b"target-server",
)
# there should be another SRV lookup
self.mock_resolver.resolve_service.assert_called_once_with(
b"_matrix._tcp.target-server",
)
# now we should get a connection to the target server
self.assertEqual(len(clients), 2)
(host, port, client_factory, _timeout, _bindAddress) = clients[1]
self.assertEqual(host, '1::f')
self.assertEqual(port, 8448)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
expected_sni=b'target-server',
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'target-server'],
)
# finish the request
request.finish()
self.reactor.pump((0.1,))
self.successResultOf(test_d)
self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")
# check the cache expires
self.reactor.pump((25 * 3600,))
self.well_known_cache.expire()
self.assertNotIn(b"testserv", self.well_known_cache)
def test_get_well_known_redirect(self):
"""Test the behaviour when the server name has no port and no SRV record, but
the .well-known has a 300 redirect
"""
self.mock_resolver.resolve_service.side_effect = lambda _: []
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once_with(
b"_matrix._tcp.testserv",
)
self.mock_resolver.resolve_service.reset_mock()
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 443)
redirect_server = self._make_connection(
client_factory,
expected_sni=b"testserv",
)
# send a 302 redirect
self.assertEqual(len(redirect_server.requests), 1)
request = redirect_server.requests[0]
request.redirect(b'https://testserv/even_better_known')
request.finish()
self.reactor.pump((0.1, ))
# now there should be another connection
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 443)
well_known_server = self._make_connection(
client_factory,
expected_sni=b"testserv",
)
self.assertEqual(len(well_known_server.requests), 1, "No request after 302")
request = well_known_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/even_better_known')
request.write(b'{ "m.server": "target-server" }')
request.finish()
self.reactor.pump((0.1, ))
# there should be another SRV lookup
self.mock_resolver.resolve_service.assert_called_once_with(
b"_matrix._tcp.target-server",
)
# now we should get a connection to the target server
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1::f')
self.assertEqual(port, 8448)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
expected_sni=b'target-server',
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'target-server'],
)
# finish the request
request.finish()
self.reactor.pump((0.1,))
self.successResultOf(test_d)
self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")
# check the cache expires
self.reactor.pump((25 * 3600,))
self.well_known_cache.expire()
self.assertNotIn(b"testserv", self.well_known_cache)
def test_get_hostname_srv(self):
"""
Test the behaviour when there is a single SRV record
@ -372,6 +612,71 @@ class MatrixFederationAgentTests(TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_get_well_known_srv(self):
"""Test the behaviour when the server name has no port and no SRV record, but
the .well-known redirects to a place where there is a SRV.
"""
self.mock_resolver.resolve_service.side_effect = lambda _: []
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["srvtarget"] = "5.6.7.8"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once_with(
b"_matrix._tcp.testserv",
)
self.mock_resolver.resolve_service.reset_mock()
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 443)
self.mock_resolver.resolve_service.side_effect = lambda _: [
Server(host=b"srvtarget", port=8443),
]
self._handle_well_known_connection(
client_factory, expected_sni=b"testserv", target_server=b"target-server",
)
# there should be another SRV lookup
self.mock_resolver.resolve_service.assert_called_once_with(
b"_matrix._tcp.target-server",
)
# now we should get a connection to the target of the SRV record
self.assertEqual(len(clients), 2)
(host, port, client_factory, _timeout, _bindAddress) = clients[1]
self.assertEqual(host, '5.6.7.8')
self.assertEqual(port, 8443)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
expected_sni=b'target-server',
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'target-server'],
)
# finish the request
request.finish()
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_idna_servername(self):
"""test the behaviour when the server name has idna chars in"""
@ -390,11 +695,25 @@ class MatrixFederationAgentTests(TestCase):
b"_matrix._tcp.xn--bcher-kva.com",
)
# Make sure treq is trying to connect
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 443)
# fonx the connection
client_factory.clientConnectionFailed(None, Exception("nope"))
# attemptdelay on the hostnameendpoint is 0.3, so takes that long before the
# .well-known request fails.
self.reactor.pump((0.4,))
# We should fall back to port 8448
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 2)
(host, port, client_factory, _timeout, _bindAddress) = clients[1]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8448)
# make a test server, and wire up the client
@ -461,6 +780,126 @@ class MatrixFederationAgentTests(TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
@defer.inlineCallbacks
def do_get_well_known(self, serv):
try:
result = yield self.agent._get_well_known(serv)
logger.info("Result from well-known fetch: %s", result)
except Exception as e:
logger.warning("Error fetching well-known: %s", e)
raise
defer.returnValue(result)
def test_well_known_cache(self):
self.reactor.lookups["testserv"] = "1.2.3.4"
fetch_d = self.do_get_well_known(b'testserv')
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 443)
well_known_server = self._handle_well_known_connection(
client_factory,
expected_sni=b"testserv",
response_headers={b'Cache-Control': b'max-age=10'},
target_server=b"target-server",
)
r = self.successResultOf(fetch_d)
self.assertEqual(r, b'target-server')
# close the tcp connection
well_known_server.loseConnection()
# repeat the request: it should hit the cache
fetch_d = self.do_get_well_known(b'testserv')
r = self.successResultOf(fetch_d)
self.assertEqual(r, b'target-server')
# expire the cache
self.reactor.pump((10.0,))
# now it should connect again
fetch_d = self.do_get_well_known(b'testserv')
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 443)
self._handle_well_known_connection(
client_factory,
expected_sni=b"testserv",
target_server=b"other-server",
)
r = self.successResultOf(fetch_d)
self.assertEqual(r, b'other-server')
class TestCachePeriodFromHeaders(TestCase):
def test_cache_control(self):
# uppercase
self.assertEqual(
_cache_period_from_headers(
Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}),
), 100,
)
# missing value
self.assertIsNone(_cache_period_from_headers(
Headers({b'Cache-Control': [b'max-age=, bar']}),
))
# hackernews: bogus due to semicolon
self.assertIsNone(_cache_period_from_headers(
Headers({b'Cache-Control': [b'private; max-age=0']}),
))
# github
self.assertEqual(
_cache_period_from_headers(
Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}),
), 0,
)
# google
self.assertEqual(
_cache_period_from_headers(
Headers({b'cache-control': [b'private, max-age=0']}),
), 0,
)
def test_expires(self):
self.assertEqual(
_cache_period_from_headers(
Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}),
time_now=lambda: 1548833700
), 33,
)
# cache-control overrides expires
self.assertEqual(
_cache_period_from_headers(
Headers({
b'cache-control': [b'max-age=10'],
b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']
}),
time_now=lambda: 1548833700
), 10,
)
# invalid expires means immediate expiry
self.assertEqual(
_cache_period_from_headers(
Headers({b'Expires': [b'0']}),
), 0,
)
def _check_logcontext(context):
current = LoggingContext.current_context()
@ -492,3 +931,11 @@ def _build_test_server():
def _log_request(request):
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)
@implementer(IPolicyForHTTPS)
class TrustingTLSPolicyForHTTPS(object):
"""An IPolicyForHTTPS which doesn't do any certificate verification"""
def creatorForNetloc(self, hostname, port):
certificateOptions = OpenSSLCertificateOptions()
return ClientTLSOptions(hostname, certificateOptions.getContext())

81
tests/http/server.pem Normal file
View file

@ -0,0 +1,81 @@
-----BEGIN PRIVATE KEY-----
MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQCgF43/3lAgJ+p0
x7Rn8UcL8a4fctvdkikvZrCngw96LkB34Evfq8YGWlOVjU+f9naUJLAKMatmAfEN
r+rMX4VOXmpTwuu6iLtqwreUrRFMESyrmvQxa15p+y85gkY0CFmXMblv6ORbxHTG
ncBGwST4WK4Poewcgt6jcISFCESTUKu1zc3cw1ANIDRyDLB5K44KwIe36dcKckyN
Kdtv4BJ+3fcIZIkPJH62zqCypgFF1oiFt40uJzClxgHdJZlKYpgkfnDTckw4Y/Mx
9k8BbE310KAzUNMV9H7I1eEolzrNr66FQj1eN64X/dqO8lTbwCqAd4diCT4sIUk0
0SVsAUjNd3g8j651hx+Qb1t8fuOjrny8dmeMxtUgIBHoQcpcj76R55Fs7KZ9uar0
8OFTyGIze51W1jG2K/7/5M1zxIqrA+7lsXu5OR81s7I+Ng/UUAhiHA/z+42/aiNa
qEuk6tqj3rHfLctnCbtZ+JrRNqSSwEi8F0lMA021ivEd2eJV+284OyJjhXOmKHrX
QADHrmS7Sh4syTZvRNm9n+qWID0KdDr2Sji/KnS3Enp44HDQ4xriT6/xhwEGsyuX
oH5aAkdLznulbWkHBbyx1SUQSTLpOqzaioF9m1vRrLsFvrkrY3D253mPJ5eU9HM/
dilduFcUgj4rz+6cdXUAh+KK/v95zwIDAQABAoICAFG5tJPaOa0ws0/KYx5s3YgL
aIhFalhCNSQtmCDrlwsYcXDA3/rfBchYdDL0YKGYgBBAal3J3WXFt/j0xThvyu2m
5UC9UPl4s7RckrsjXqEmY1d3UxGnbhtMT19cUdpeKN42VCP9EBaIw9Rg07dLAkSF
gNYaIx6q8F0fI4eGIPvTQtUcqur4CfWpaxyNvckdovV6M85/YXfDwbCOnacPDGIX
jfSK3i0MxGMuOHr6o8uzKR6aBUh6WStHWcw7VXXTvzdiFNbckmx3Gb93rf1b/LBw
QFfx+tBKcC62gKroCOzXso/0sL9YTVeSD/DJZOiJwSiz3Dj/3u1IUMbVvfTU8wSi
CYS7Z+jHxwSOCSSNTXm1wO/MtDsNKbI1+R0cohr/J9pOMQvrVh1+2zSDOFvXAQ1S
yvjn+uqdmijRoV2VEGVHd+34C+ci7eJGAhL/f92PohuuFR2shUETgGWzpACZSJwg
j1d90Hs81hj07vWRb+xCeDh00vimQngz9AD8vYvv/S4mqRGQ6TZdfjLoUwSTg0JD
6sQgRXX026gQhLhn687vLKZfHwzQPZkpQdxOR0dTZ/ho/RyGGRJXH4kN4cA2tPr+
AKYQ29YXGlEzGG7OqikaZcprNWG6UFgEpuXyBxCgp9r4ladZo3J+1Rhgus8ZYatd
uO98q3WEBmP6CZ2n32mBAoIBAQDS/c/ybFTos0YpGHakwdmSfj5OOQJto2y8ywfG
qDHwO0ebcpNnS1+MA+7XbKUQb/3Iq7iJljkkzJG2DIJ6rpKynYts1ViYpM7M/t0T
W3V1gvUcUL62iqkgws4pnpWmubFkqV31cPSHcfIIclnzeQ1aOEGsGHNAvhty0ciC
DnkJACbqApvopFLOR5f6UFTtKExE+hDH0WqgpsCAKJ1L4g6pBzZatI32/CN9JEVU
tDbxLV75hHlFFjUrG7nT1rPyr/gI8Ceh9/2xeXPfjJUR0PrG3U1nwLqUCZkvFzO6
XpN2+A+/v4v5xqMjKDKDFy1oq6SCMomwv/viw6wl/84TMbolAoIBAQDCPiMecnR8
REik6tqVzQO/uSe9ZHjz6J15t5xdwaI6HpSwLlIkQPkLTjyXtFpemK5DOYRxrJvQ
remfrZrN2qtLlb/DKpuGPWRsPOvWCrSuNEp48ivUehtclljrzxAFfy0sM+fWeJ48
nTnR+td9KNhjNtZixzWdAy/mE+jdaMsXVnk66L73Uz+2WsnvVMW2R6cpCR0F2eP/
B4zDWRqlT2w47sePAB81mFYSQLvPC6Xcgg1OqMubfiizJI49c8DO6Jt+FFYdsxhd
kG52Eqa/Net6rN3ueiS6yXL5TU3Y6g96bPA2KyNCypucGcddcBfqaiVx/o4AH6yT
NrdsrYtyvk/jAoIBAQDHUwKVeeRJJbvdbQAArCV4MI155n+1xhMe1AuXkCQFWGtQ
nlBE4D72jmyf1UKnIbW2Uwv15xY6/ouVWYIWlj9+QDmMaozVP7Uiko+WDuwLRNl8
k4dn+dzHV2HejbPBG2JLv3lFOx23q1zEwArcaXrExaq9Ayg2fKJ/uVHcFAIiD6Oz
pR1XDY4w1A/uaN+iYFSVQUyDCQLbnEz1hej73CaPZoHh9Pq83vxD5/UbjVjuRTeZ
L55FNzKpc/r89rNvTPBcuUwnxplDhYKDKVNWzn9rSXwrzTY2Tk8J3rh+k4RqevSd
6D47jH1n5Dy7/TRn0ueKHGZZtTUnyEUkbOJo3ayFAoIBAHKDyZaQqaX9Z8p6fwWj
yVsFoK0ih8BcWkLBAdmwZ6DWGJjJpjmjaG/G3ygc9s4gO1R8m12dAnuDnGE8KzDD
gwtbrKM2Alyg4wyA2hTlWOH/CAzH0RlCJ9Fs/d1/xJVJBeuyajLiB3/6vXTS6qnq
I7BSSxAPG8eGcn21LSsjNeB7ZZtaTgNnu/8ZBUYo9yrgkWc67TZe3/ChldYxOOlO
qqHh/BqNWtjxB4VZTp/g4RbgQVInZ2ozdXEv0v/dt0UEk29ANAjsZif7F3RayJ2f
/0TilzCaJ/9K9pKNhaClVRy7Dt8QjYg6BIWCGSw4ApF7pLnQ9gySn95mersCkVzD
YDsCggEAb0E/TORjQhKfNQvahyLfQFm151e+HIoqBqa4WFyfFxe/IJUaLH/JSSFw
VohbQqPdCmaAeuQ8ERL564DdkcY5BgKcax79fLLCOYP5bT11aQx6uFpfl2Dcm6Z9
QdCRI4jzPftsd5fxLNH1XtGyC4t6vTic4Pji2O71WgWzx0j5v4aeDY4sZQeFxqCV
/q7Ee8hem1Rn5RFHu14FV45RS4LAWl6wvf5pQtneSKzx8YL0GZIRRytOzdEfnGKr
FeUlAj5uL+5/p0ZEgM7gPsEBwdm8scF79qSUn8UWSoXNeIauF9D4BDg8RZcFFxka
KILVFsq3cQC+bEnoM4eVbjEQkGs1RQ==
-----END PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIIE/jCCAuagAwIBAgIJANFtVaGvJWZlMA0GCSqGSIb3DQEBCwUAMBMxETAPBgNV
BAMMCHRlc3RzZXJ2MCAXDTE5MDEyNzIyMDIzNloYDzIxMTkwMTAzMjIwMjM2WjAT
MREwDwYDVQQDDAh0ZXN0c2VydjCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoC
ggIBAKAXjf/eUCAn6nTHtGfxRwvxrh9y292SKS9msKeDD3ouQHfgS9+rxgZaU5WN
T5/2dpQksAoxq2YB8Q2v6sxfhU5ealPC67qIu2rCt5StEUwRLKua9DFrXmn7LzmC
RjQIWZcxuW/o5FvEdMadwEbBJPhYrg+h7ByC3qNwhIUIRJNQq7XNzdzDUA0gNHIM
sHkrjgrAh7fp1wpyTI0p22/gEn7d9whkiQ8kfrbOoLKmAUXWiIW3jS4nMKXGAd0l
mUpimCR+cNNyTDhj8zH2TwFsTfXQoDNQ0xX0fsjV4SiXOs2vroVCPV43rhf92o7y
VNvAKoB3h2IJPiwhSTTRJWwBSM13eDyPrnWHH5BvW3x+46OufLx2Z4zG1SAgEehB
ylyPvpHnkWzspn25qvTw4VPIYjN7nVbWMbYr/v/kzXPEiqsD7uWxe7k5HzWzsj42
D9RQCGIcD/P7jb9qI1qoS6Tq2qPesd8ty2cJu1n4mtE2pJLASLwXSUwDTbWK8R3Z
4lX7bzg7ImOFc6YoetdAAMeuZLtKHizJNm9E2b2f6pYgPQp0OvZKOL8qdLcSenjg
cNDjGuJPr/GHAQazK5egfloCR0vOe6VtaQcFvLHVJRBJMuk6rNqKgX2bW9GsuwW+
uStjcPbneY8nl5T0cz92KV24VxSCPivP7px1dQCH4or+/3nPAgMBAAGjUzBRMB0G
A1UdDgQWBBQcQZpzLzTk5KdS/Iz7sGCV7gTd/zAfBgNVHSMEGDAWgBQcQZpzLzTk
5KdS/Iz7sGCV7gTd/zAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IC
AQAr/Pgha57jqYsDDX1LyRrVdqoVBpLBeB7x/p9dKYm7S6tBTDFNMZ0SZyQP8VEG
7UoC9/OQ9nCdEMoR7ZKpQsmipwcIqpXHS6l4YOkf5EEq5jpMgvlEesHmBJJeJew/
FEPDl1bl8d0tSrmWaL3qepmwzA+2lwAAouWk2n+rLiP8CZ3jZeoTXFqYYrUlEqO9
fHMvuWqTV4KCSyNY+GWCrnHetulgKHlg+W2J1mZnrCKcBhWf9C2DesTJO+JldIeM
ornTFquSt21hZi+k3aySuMn2N3MWiNL8XsZVsAnPSs0zA+2fxjJkShls8Gc7cCvd
a6XrNC+PY6pONguo7rEU4HiwbvnawSTngFFglmH/ImdA/HkaAekW6o82aI8/UxFx
V9fFMO3iKDQdOrg77hI1bx9RlzKNZZinE2/Pu26fWd5d2zqDWCjl8ykGQRAfXgYN
H3BjgyXLl+ao5/pOUYYtzm3ruTXTgRcy5hhL6hVTYhSrf9vYh4LNIeXNKnZ78tyG
TX77/kU2qXhBGCFEUUMqUNV/+ITir2lmoxVjknt19M07aGr8C7SgYt6Rs+qDpMiy
JurgvRh8LpVq4pHx1efxzxCFmo58DMrG40I0+CF3y/niNpOb1gp2wAqByRiORkds
f0ytW6qZ0TpHbD6gOtQLYDnhx3ISuX+QYSekVwQUpffeWQ==
-----END CERTIFICATE-----

View file

@ -354,7 +354,13 @@ class FakeTransport(object):
:type: twisted.internet.interfaces.IReactorTime
"""
_protocol = attr.ib(default=None)
"""The Protocol which is producing data for this transport. Optional, but if set
will get called back for connectionLost() notifications etc.
"""
disconnecting = False
disconnected = False
buffer = attr.ib(default=b'')
producer = attr.ib(default=None)
@ -364,11 +370,17 @@ class FakeTransport(object):
def getHost(self):
return None
def loseConnection(self):
self.disconnecting = True
def loseConnection(self, reason=None):
if not self.disconnecting:
logger.info("FakeTransport: loseConnection(%s)", reason)
self.disconnecting = True
if self._protocol:
self._protocol.connectionLost(reason)
self.disconnected = True
def abortConnection(self):
self.disconnecting = True
logger.info("FakeTransport: abortConnection()")
self.loseConnection()
def pauseProducing(self):
if not self.producer:
@ -407,9 +419,16 @@ class FakeTransport(object):
# TLSMemoryBIOProtocol
return
if self.disconnected:
return
logger.info("%s->%s: %s", self._protocol, self.other, self.buffer)
if getattr(self.other, "transport") is not None:
self.other.dataReceived(self.buffer)
self.buffer = b""
try:
self.other.dataReceived(self.buffer)
self.buffer = b""
except Exception as e:
logger.warning("Exception writing to protocol: %s", e)
return
self._reactor.callLater(0.0, _write)

View file

@ -11,7 +11,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver(
self.addCleanup
) # type: synapse.server.HomeServer
)
self.store = hs.get_datastore()
self.clock = hs.get_clock()

View file

@ -20,9 +20,6 @@ import tests.utils
class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(EndToEndKeyStoreTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
def setUp(self):

View file

@ -22,9 +22,6 @@ import tests.utils
class KeyStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(KeyStoreTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.keys.KeyStore
@defer.inlineCallbacks
def setUp(self):

View file

@ -28,9 +28,6 @@ logger = logging.getLogger(__name__)
class StateStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(StateStoreTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
def setUp(self):

View file

@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for running the unit tests
"""

View file

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import twisted.logger
from synapse.util.logcontext import LoggingContextFilter
class ToTwistedHandler(logging.Handler):
"""logging handler which sends the logs to the twisted log"""
tx_log = twisted.logger.Logger()
def emit(self, record):
log_entry = self.format(record)
log_level = record.levelname.lower().replace('warning', 'warn')
self.tx_log.emit(
twisted.logger.LogLevel.levelWithName(log_level),
log_entry.replace("{", r"(").replace("}", r")"),
)
def setup_logging():
"""Configure the python logging appropriately for the tests.
(Logs will end up in _trial_temp.)
"""
root_logger = logging.getLogger()
log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
)
handler = ToTwistedHandler()
formatter = logging.Formatter(log_format)
handler.setFormatter(formatter)
handler.addFilter(LoggingContextFilter(request=""))
root_logger.addHandler(handler)
log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")
root_logger.setLevel(log_level)

View file

@ -166,7 +166,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def inject_message(self, user_id, content=None):
if content is None:
content = {"body": "testytest"}
content = {"body": "testytest", "msgtype": "m.text"}
builder = self.event_builder_factory.new(
RoomVersions.V1,
{

View file

@ -31,38 +31,14 @@ from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
from synapse.util.logcontext import LoggingContext, LoggingContextFilter
from synapse.util.logcontext import LoggingContext
from tests.server import get_clock, make_request, render, setup_test_homeserver
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
setupdb()
# Set up putting Synapse's logs into Trial's.
rootLogger = logging.getLogger()
log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
)
class ToTwistedHandler(logging.Handler):
tx_log = twisted.logger.Logger()
def emit(self, record):
log_entry = self.format(record)
log_level = record.levelname.lower().replace('warning', 'warn')
self.tx_log.emit(
twisted.logger.LogLevel.levelWithName(log_level),
log_entry.replace("{", r"(").replace("}", r")"),
)
handler = ToTwistedHandler()
formatter = logging.Formatter(log_format)
handler.setFormatter(formatter)
handler.addFilter(LoggingContextFilter(request=""))
rootLogger.addHandler(handler)
setup_logging()
def around(target):
@ -96,7 +72,7 @@ class TestCase(unittest.TestCase):
method = getattr(self, methodName)
level = getattr(method, "loglevel", getattr(self, "loglevel", logging.WARNING))
level = getattr(method, "loglevel", getattr(self, "loglevel", None))
@around(self)
def setUp(orig):
@ -114,7 +90,7 @@ class TestCase(unittest.TestCase):
)
old_level = logging.getLogger().level
if old_level != level:
if level is not None and old_level != level:
@around(self)
def tearDown(orig):
@ -122,7 +98,8 @@ class TestCase(unittest.TestCase):
logging.getLogger().setLevel(old_level)
return ret
logging.getLogger().setLevel(level)
logging.getLogger().setLevel(level)
return orig()
@around(self)

View file

@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from synapse.util.caches.ttlcache import TTLCache
from tests import unittest
class CacheTestCase(unittest.TestCase):
def setUp(self):
self.mock_timer = Mock(side_effect=lambda: 100.0)
self.cache = TTLCache("test_cache", self.mock_timer)
def test_get(self):
"""simple set/get tests"""
self.cache.set('one', '1', 10)
self.cache.set('two', '2', 20)
self.cache.set('three', '3', 30)
self.assertEqual(len(self.cache), 3)
self.assertTrue('one' in self.cache)
self.assertEqual(self.cache.get('one'), '1')
self.assertEqual(self.cache['one'], '1')
self.assertEqual(self.cache.get_with_expiry('one'), ('1', 110))
self.assertEqual(self.cache._metrics.hits, 3)
self.assertEqual(self.cache._metrics.misses, 0)
self.cache.set('two', '2.5', 20)
self.assertEqual(self.cache['two'], '2.5')
self.assertEqual(self.cache._metrics.hits, 4)
# non-existent-item tests
self.assertEqual(self.cache.get('four', '4'), '4')
self.assertIs(self.cache.get('four', None), None)
with self.assertRaises(KeyError):
self.cache['four']
with self.assertRaises(KeyError):
self.cache.get('four')
with self.assertRaises(KeyError):
self.cache.get_with_expiry('four')
self.assertEqual(self.cache._metrics.hits, 4)
self.assertEqual(self.cache._metrics.misses, 5)
def test_expiry(self):
self.cache.set('one', '1', 10)
self.cache.set('two', '2', 20)
self.cache.set('three', '3', 30)
self.assertEqual(len(self.cache), 3)
self.assertEqual(self.cache['one'], '1')
self.assertEqual(self.cache['two'], '2')
# enough for the first entry to expire, but not the rest
self.mock_timer.side_effect = lambda: 110.0
self.assertEqual(len(self.cache), 2)
self.assertFalse('one' in self.cache)
self.assertEqual(self.cache['two'], '2')
self.assertEqual(self.cache['three'], '3')
self.assertEqual(self.cache.get_with_expiry('two'), ('2', 120))
self.assertEqual(self.cache._metrics.hits, 5)
self.assertEqual(self.cache._metrics.misses, 0)