Merge branch 'anoa/room_dir_quick_fix' into matrix-org-hotfixes

Andrew Morgan, 2019-01-24 14:51:35 +00:00
commit acaca1b4e9
93 changed files with 2360 additions and 1003 deletions

.codecov.yml (new file)
@@ -0,0 +1,15 @@
comment:
layout: "diff"
coverage:
status:
project:
default:
target: 0 # Target % coverage, can be auto. Turned off for now
threshold: null
base: auto
patch:
default:
target: 0
threshold: null
base: auto

.coveragerc
@@ -1,11 +1,7 @@
[run]
branch = True
parallel = True
source = synapse
[paths]
source=
coverage
include = synapse/*
[report]
precision = 2
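
For context, these .coveragerc keys map onto coverage.py's Python API. A rough
equivalent in code (a sketch, assuming the coverage package is installed;
"parallel = True" and the [report] precision setting are normally left to the
config file):

import coverage

# branch=True and include mirror the [run] section above.
cov = coverage.Coverage(branch=True, include=["synapse/*"])
cov.start()
# ... exercise the code under test ...
cov.stop()
cov.report()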

.gitignore (vendored)
@@ -25,9 +25,9 @@ homeserver*.pid
*.tls.dh
*.tls.key
.coverage
.coverage.*
!.coverage.rc
.coverage*
coverage.*
!.coveragerc
htmlcov
demo/*/*.db

.travis.yml
@@ -12,6 +12,9 @@ cache:
#
- $HOME/.cache/pip/wheels
addons:
postgresql: "9.4"
# don't clone the whole repo history, one commit will do
git:
depth: 1
@@ -68,6 +71,13 @@ matrix:
install:
- pip install tox
# if we don't have python3.6 in this environment, travis unhelpfully gives us
# a `python3.6` on our path which does nothing but spit out a warning. Tox
# tries to run it (even if we're not running a py36 env), so the build logs
# then have warnings which look like errors. To reduce the noise, remove the
# non-functional python3.6.
- ( ! command -v python3.6 || python3.6 --version ) &>/dev/null || rm -f $(command -v python3.6)
script:
- tox -e $TOX_ENV

MANIFEST.in
@@ -37,6 +37,7 @@ prune docker
prune .circleci
prune .coveragerc
prune debian
prune .codecov.yml
exclude jenkins*
recursive-exclude jenkins *.sh

README.rst
@@ -184,7 +184,7 @@ Configuring Synapse
Before you can start Synapse, you will need to generate a configuration
file. To do this, run (in your virtualenv, as before)::
cd ~/.synapse
cd ~/synapse
python -m synapse.app.homeserver \
--server-name my.domain.name \
--config-path homeserver.yaml \
@@ -220,7 +220,7 @@ is configured to use TLS with a self-signed certificate. If you would like
to do an initial test with a client without having to set up a reverse proxy,
you can temporarily use another certificate. (Note that a self-signed
certificate is fine for `Federation`_). You can do so by changing
``tls_certificate_path``, ``tls_private_key_path`` and ``tls_dh_params_path``
``tls_certificate_path`` and ``tls_private_key_path``
in ``homeserver.yaml``; alternatively, you can use a reverse-proxy, but be sure
to read `Using a reverse proxy with Synapse`_ when doing so.
@@ -796,8 +796,7 @@ A manual password reset can be done via direct database access as follows.
First calculate the hash of the new password::
$ source ~/.synapse/bin/activate
$ ./scripts/hash_password
$ ~/synapse/env/bin/hash_password
Password:
Confirm password:
$2a$12$xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx

changelog.d/4229.feature (new file)
@@ -0,0 +1 @@
Synapse's cipher string has been updated to require ECDH key exchange. Configuring and generating dh_params is no longer required, and they will be ignored.
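
A quick way to sanity-check an ECDH-only cipher string like the one this change
introduces is to feed it to Python's ssl module. A minimal sketch (the cipher
string is the one added to synapse/crypto/context_factory.py in this commit):

import ssl

CIPHERS = "ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1"

ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
ctx.set_ciphers(CIPHERS)  # raises ssl.SSLError if the string is invalid

# Every cipher left in the context uses ECDHE key exchange, which is why
# dh_params (only needed for classic DHE) can now be ignored.
for cipher in ctx.get_ciphers():
    print(cipher["name"])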

changelog.d/4306.misc (new file)
@@ -0,0 +1 @@
Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+.
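
"Native UPSERT" means a single INSERT ... ON CONFLICT statement instead of a
racy SELECT-then-INSERT/UPDATE pair. A minimal sqlite3 sketch (the simplified
table is illustrative, not Synapse's actual schema):

import sqlite3

# Requires SQLite 3.24+; PostgreSQL 9.5+ accepts the same ON CONFLICT syntax.
assert sqlite3.sqlite_version_info >= (3, 24, 0), sqlite3.sqlite_version

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE user_ips "
    "(user_id TEXT, ip TEXT, last_seen INTEGER, UNIQUE(user_id, ip))"
)

# One round trip: insert, or update last_seen if the row already exists.
conn.execute(
    "INSERT INTO user_ips (user_id, ip, last_seen) VALUES (?, ?, ?) "
    "ON CONFLICT (user_id, ip) DO UPDATE SET last_seen = excluded.last_seen",
    ("@alice:example.org", "10.0.0.1", 1548340295000),
)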

changelog.d/4342.misc (new file)
@@ -0,0 +1 @@
Update README to use the new virtualenv everywhere

changelog.d/4368.misc (new file)
@@ -0,0 +1 @@
Add better logging for unexpected errors while sending transactions

changelog.d/4369.bugfix (new file)
@@ -0,0 +1 @@
Prevent users with access tokens predating the introduction of device IDs from creating spurious entries in the user_ips table.

changelog.d/4370.misc (new file)
@@ -0,0 +1 @@
Apply a unique index to the user_ips table, preventing duplicates.

changelog.d/4377.misc (new file)
@@ -0,0 +1 @@
Silence travis-ci build warnings by removing non-functional python3.6

changelog.d/4384.feature (new file)
@@ -0,0 +1 @@
Synapse can now automatically provision TLS certificates via ACME (the protocol used by CAs like Let's Encrypt).

changelog.d/4387.misc (new file)
@@ -0,0 +1 @@
Fix a comment in the generated config file

changelog.d/4390.misc (new file)
@@ -0,0 +1 @@
Add ground work for implementing future federation API versions

changelog.d/4392.bugfix (new file)
@@ -0,0 +1 @@
Fix typo in ALL_USER_TYPES definition to ensure type is a tuple

changelog.d/4397.bugfix (new file)
@@ -0,0 +1 @@
Fix high CPU usage due to remote devicelist updates

changelog.d/4399.misc (new file)
@@ -0,0 +1 @@
Update dependencies on msgpack and pymacaroons to use the up-to-date packages.

changelog.d/4400.misc (new file)
@@ -0,0 +1 @@
Tweak codecov settings to make them less loud.

changelog.d/4402.misc (new file)
@@ -0,0 +1 @@
Implement server support for MSC1794 - Federation v2 Invite API

changelog.d/4404.bugfix (new file)
@@ -0,0 +1 @@
Fix potential bug where creating or joining a room could fail

changelog.d/4407.bugfix (new file)
@@ -0,0 +1 @@
Fix incorrect logcontexts after a Deferred was cancelled

changelog.d/4408.misc (new file)
@@ -0,0 +1 @@
Refactor 'sign_request' as 'build_auth_headers'

changelog.d/4409.misc (new file)
@@ -0,0 +1 @@
Remove redundant federation connection wrapping code

changelog.d/4411.bugfix (new file)
@@ -0,0 +1 @@
Ensure encrypted room state is persisted across room upgrades.

changelog.d/4423.feature (new file)
@@ -0,0 +1 @@
Config option to disable requesting MSISDN on registration.

changelog.d/4426.misc (new file)
@@ -0,0 +1 @@
Remove redundant SynapseKeyClientProtocol magic

changelog.d/4427.misc (new file)
@@ -0,0 +1 @@
Refactor and cleanup for SRV record lookup

changelog.d/4428.misc (new file)
@@ -0,0 +1 @@
Move SRV logic into the Agent layer

changelog.d/4432.misc (new file)
@@ -0,0 +1 @@
Apply a unique index to the user_ips table, preventing duplicates.

changelog.d/4433.misc (new file)
@@ -0,0 +1 @@
debian package: symlink to explicit python version

changelog.d/4434.misc (new file)
@@ -0,0 +1 @@
Apply a unique index to the user_ips table, preventing duplicates.

changelog.d/4445.feature (new file)
@@ -0,0 +1 @@
Add a metric for tracking event stream position of the user directory.

changelog.d/4452.bugfix (new file)
@@ -0,0 +1 @@
Don't send IP addresses as SNI

changelog.d/4461.bugfix (new file)
@@ -0,0 +1 @@
Add a timeout to filtered room directory queries.

debian/build_virtualenv
@@ -6,7 +6,16 @@
set -e
export DH_VIRTUALENV_INSTALL_ROOT=/opt/venvs
SNAKE=/usr/bin/python3
# make sure that the virtualenv links to the specific version of python, by
# dereferencing the python3 symlink.
#
# Otherwise, if somebody tries to install (say) the stretch package on buster,
# they will get a confusing error about "No module named 'synapse'", because
# python won't look in the right directory. At least this way, the error will
# be a *bit* more obvious.
#
SNAKE=`readlink -e /usr/bin/python3`
# try to set the CFLAGS so any compiled C extensions are compiled with the most
# generic x64 instructions possible, so that compiling it on a new Intel chip
@@ -46,3 +55,7 @@ cp -r tests "$tmpdir"
PYTHONPATH="$tmpdir" \
debian/matrix-synapse-py3/opt/venvs/matrix-synapse/bin/python \
-B -m twisted.trial --reporter=text -j2 tests
# add a dependency on the right version of python to substvars.
PYPKG=`basename $SNAKE`
echo "synapse:pydepends=$PYPKG" >> debian/matrix-synapse-py3.substvars

debian/changelog (vendored)
@@ -1,3 +1,15 @@
matrix-synapse-py3 (0.34.1.1++1) stable; urgency=medium
* Update conflicts specifications to allow smoother transition from matrix-synapse.
-- Synapse Packaging team <packages@matrix.org> Sat, 12 Jan 2019 12:58:35 +0000
matrix-synapse-py3 (0.34.1.1) stable; urgency=high
* New synapse release 0.34.1.1
-- Synapse Packaging team <packages@matrix.org> Thu, 10 Jan 2019 15:04:52 +0000
matrix-synapse-py3 (0.34.1+1) stable; urgency=medium
* Remove 'Breaks: matrix-synapse-ldap3'. (matrix-synapse-py3 includes

debian/control (vendored)
@@ -19,16 +19,16 @@ Homepage: https://github.com/matrix-org/synapse
Package: matrix-synapse-py3
Architecture: amd64
Provides: matrix-synapse
Breaks:
matrix-synapse (<< 0.34.0-0matrix2),
matrix-synapse (>= 0.34.0-1),
Conflicts:
matrix-synapse (<< 0.34.0.1-0matrix2),
matrix-synapse (>= 0.34.0.1-1),
Pre-Depends: dpkg (>= 1.16.1)
Depends:
adduser,
debconf,
python3-distutils|libpython3-stdlib (<< 3.6),
python3,
${misc:Depends},
${synapse:pydepends},
# some of our scripts use perl, but none of them are important,
# so we put perl:Depends in Suggests rather than Depends.
Suggests:

debian/homeserver.yaml
@@ -9,9 +9,6 @@ tls_certificate_path: "/etc/matrix-synapse/homeserver.tls.crt"
# PEM encoded private key for TLS
tls_private_key_path: "/etc/matrix-synapse/homeserver.tls.key"
# PEM dh parameters for ephemeral keys
tls_dh_params_path: "/etc/matrix-synapse/homeserver.tls.dh"
# Don't bind to the https port
no_tls: False

(deleted file: default 2048-bit DH parameters)
@@ -1,9 +0,0 @@
2048-bit DH parameters taken from rfc3526
-----BEGIN DH PARAMETERS-----
MIIBCAKCAQEA///////////JD9qiIWjCNMTGYouA3BzRKQJOCIpnzHQCC76mOxOb
IlFKCHmONATd75UZs806QxswKwpt8l8UN0/hNW1tUcJF5IW1dmJefsb0TELppjft
awv/XLb0Brft7jhr+1qJn6WunyQRfEsf5kkoZlHs5Fs9wgB8uKFjvwWY2kg2HFXT
mmkWP6j9JM9fg2VdI9yjrZYcYvNWIIVSu57VKQdwlpZtZww1Tkq8mATxdGwIyhgh
fDKQXkYuNs474553LBgOhgObJ4Oi7Aeij7XFXfBvTFLJ3ivL9pVYFxg5lUl86pVq
5RXSJhiY+gUQFXKOWoqsqmj//////////wIBAg==
-----END DH PARAMETERS-----

docker/build_debian_packages.sh (deleted; replaced by scripts-dev/build_debian_packages below)
@@ -1,46 +0,0 @@
#!/bin/bash
# Build the Debian packages using Docker images.
#
# This script builds the Docker images and then executes them sequentially, each
# one building a Debian package for the targeted operating system. It is
# designed to be a "single command" to produce all the images.
#
# By default, builds for all known distributions, but a list of distributions
# can be passed on the commandline for debugging.
set -ex
cd `dirname $0`
if [ $# -lt 1 ]; then
DISTS=(
debian:stretch
debian:buster
debian:sid
ubuntu:xenial
ubuntu:bionic
ubuntu:cosmic
)
else
DISTS=("$@")
fi
# Make the dir where the debs will live.
#
# Note that we deliberately put this outside the source tree, otherwise we tend
# to get source packages which are full of debs. (We could hack around that
# with more magic in the build_debian.sh script, but that doesn't solve the
# problem for natively-run dpkg-buildpackage).
mkdir -p ../../debs
# Build each OS image;
for i in "${DISTS[@]}"; do
TAG=$(echo ${i} | cut -d ":" -f 2)
docker build --tag dh-venv-builder:${TAG} --build-arg distro=${i} -f Dockerfile-dhvirtualenv .
docker run -it --rm --volume=$(pwd)/../\:/synapse/source:ro --volume=$(pwd)/../../debs:/debs \
-e TARGET_USERID=$(id -u) \
-e TARGET_GROUPID=$(id -g) \
dh-venv-builder:${TAG}
done

docker/conf/homeserver.yaml
@@ -4,7 +4,6 @@
tls_certificate_path: "/data/{{ SYNAPSE_SERVER_NAME }}.tls.crt"
tls_private_key_path: "/data/{{ SYNAPSE_SERVER_NAME }}.tls.key"
tls_dh_params_path: "/data/{{ SYNAPSE_SERVER_NAME }}.tls.dh"
no_tls: {{ "True" if SYNAPSE_NO_TLS else "False" }}
tls_fingerprints: []

scripts-dev/build_debian_packages (new executable file)
@@ -0,0 +1,154 @@
#!/usr/bin/env python3
# Build the Debian packages using Docker images.
#
# This script builds the Docker images and then executes them sequentially, each
# one building a Debian package for the targeted operating system. It is
# designed to be a "single command" to produce all the images.
#
# By default, builds for all known distributions, but a list of distributions
# can be passed on the commandline for debugging.
import argparse
import os
import signal
import subprocess
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
DISTS = (
"debian:stretch",
"debian:buster",
"debian:sid",
"ubuntu:xenial",
"ubuntu:bionic",
"ubuntu:cosmic",
)
DESC = '''\
Builds .debs for synapse, using a Docker image for the build environment.
By default, builds for all known distributions, but a list of distributions
can be passed on the commandline for debugging.
'''
class Builder(object):
def __init__(self, redirect_stdout=False):
self.redirect_stdout = redirect_stdout
self.active_containers = set()
self._lock = threading.Lock()
self._failed = False
def run_build(self, dist):
"""Build deb for a single distribution"""
if self._failed:
print("not building %s due to earlier failure" % (dist, ))
raise Exception("failed")
try:
self._inner_build(dist)
except Exception as e:
print("build of %s failed: %s" % (dist, e), file=sys.stderr)
self._failed = True
raise
def _inner_build(self, dist):
projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
os.chdir(projdir)
tag = dist.split(":", 1)[1]
# Make the dir where the debs will live.
#
# Note that we deliberately put this outside the source tree, otherwise
# we tend to get source packages which are full of debs. (We could hack
# around that with more magic in the build_debian.sh script, but that
# doesn't solve the problem for natively-run dpkg-buildpackage).
debsdir = os.path.join(projdir, '../debs')
os.makedirs(debsdir, exist_ok=True)
if self.redirect_stdout:
logfile = os.path.join(debsdir, "%s.buildlog" % (tag, ))
print("building %s: directing output to %s" % (dist, logfile))
stdout = open(logfile, "w")
else:
stdout = None
# first build a docker image for the build environment
subprocess.check_call([
"docker", "build",
"--tag", "dh-venv-builder:" + tag,
"--build-arg", "distro=" + dist,
"-f", "docker/Dockerfile-dhvirtualenv",
"docker",
], stdout=stdout, stderr=subprocess.STDOUT)
container_name = "synapse_build_" + tag
with self._lock:
self.active_containers.add(container_name)
# then run the build itself
subprocess.check_call([
"docker", "run",
"--rm",
"--name", container_name,
"--volume=" + projdir + ":/synapse/source:ro",
"--volume=" + debsdir + ":/debs",
"-e", "TARGET_USERID=%i" % (os.getuid(), ),
"-e", "TARGET_GROUPID=%i" % (os.getgid(), ),
"dh-venv-builder:" + tag,
], stdout=stdout, stderr=subprocess.STDOUT)
with self._lock:
self.active_containers.remove(container_name)
if stdout is not None:
stdout.close()
print("Completed build of %s" % (dist, ))
def kill_containers(self):
with self._lock:
active = list(self.active_containers)
for c in active:
print("killing container %s" % (c,))
subprocess.run([
"docker", "kill", c,
], stdout=subprocess.DEVNULL)
with self._lock:
self.active_containers.remove(c)
def run_builds(dists, jobs=1):
builder = Builder(redirect_stdout=(jobs > 1))
def sig(signum, _frame):
print("Caught SIGINT")
builder.kill_containers()
signal.signal(signal.SIGINT, sig)
with ThreadPoolExecutor(max_workers=jobs) as e:
res = e.map(builder.run_build, dists)
# make sure we consume the iterable so that exceptions are raised.
for r in res:
pass
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description=DESC,
)
parser.add_argument(
'-j', '--jobs', type=int, default=1,
help='specify the number of builds to run in parallel',
)
parser.add_argument(
'dist', nargs='*', default=DISTS,
help='a list of distributions to build for. Default: %(default)s',
)
args = parser.parse_args()
run_builds(dists=args.dist, jobs=args.jobs)
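
The loop that drains e.map(...) at the end of run_builds is load-bearing:
ThreadPoolExecutor.map returns a lazy iterator, and a worker's exception only
propagates once its result is consumed. A standalone sketch of that behaviour:

from concurrent.futures import ThreadPoolExecutor

def work(n):
    if n == 2:
        raise RuntimeError("build %d failed" % n)
    return n

with ThreadPoolExecutor(max_workers=2) as executor:
    results = executor.map(work, [1, 2, 3])  # nothing raised yet
    try:
        for r in results:  # consuming the iterator surfaces the failure
            print("completed", r)
    except RuntimeError as e:
        print("caught:", e)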

synapse/api/constants.py
@@ -68,6 +68,7 @@ class EventTypes(object):
Aliases = "m.room.aliases"
Redaction = "m.room.redaction"
ThirdPartyInvite = "m.room.third_party_invite"
Encryption = "m.room.encryption"
RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias"
@@ -128,4 +129,4 @@ class UserTypes(object):
'admin' and 'guest' users should also be UserTypes. Normal users are type None
"""
SUPPORT = "support"
ALL_USER_TYPES = (SUPPORT)
ALL_USER_TYPES = (SUPPORT,)
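
The bug fixed here (changelog.d/4392.bugfix above) is that (SUPPORT) is just a
parenthesised string, while (SUPPORT,) is a one-element tuple; with a plain
string, membership tests silently become substring checks. A minimal
demonstration:

SUPPORT = "support"

not_a_tuple = (SUPPORT)   # parentheses only group: this is the str "support"
a_tuple = (SUPPORT,)      # the trailing comma is what makes a tuple

assert isinstance(not_a_tuple, str)
assert isinstance(a_tuple, tuple)

# With the str, `in` does substring matching, so bogus values pass:
assert "port" in not_a_tuple
assert "port" not in a_tuple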

synapse/api/urls.py
@@ -24,7 +24,9 @@ from synapse.config import ConfigError
CLIENT_PREFIX = "/_matrix/client/api/v1"
CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"
FEDERATION_PREFIX = "/_matrix/federation/v1"
FEDERATION_PREFIX = "/_matrix/federation"
FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1"
FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2"
STATIC_PREFIX = "/_matrix/static"
WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content"

synapse/app/homeserver.py
@@ -13,10 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import logging
import os
import sys
import traceback
from six import iteritems
@@ -324,17 +326,12 @@ def setup(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
database_engine = create_engine(config.database_config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
hs = SynapseHomeServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
@@ -361,12 +358,53 @@
logger.info("Database prepared in %s.", config.database_config['name'])
hs.setup()
hs.start_listening()
@defer.inlineCallbacks
def start():
hs.get_pusherpool().start()
hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates()
try:
# Check if the certificate is still valid.
cert_days_remaining = hs.config.is_disk_cert_valid()
if hs.config.acme_enabled:
# If ACME is enabled, we might need to provision a certificate
# before starting.
acme = hs.get_acme_handler()
# Start up the webservices which we will respond to ACME
# challenges with.
yield acme.start_listening()
# We want to reprovision if cert_days_remaining is None (meaning no
# certificate exists), or the days remaining number it returns
# is less than our re-registration threshold.
if (cert_days_remaining is None) or (
not cert_days_remaining > hs.config.acme_reprovision_threshold
):
yield acme.provision_certificate()
# Read the certificate from disk and build the context factories for
# TLS.
hs.config.read_certificate_from_disk()
hs.tls_server_context_factory = context_factory.ServerContextFactory(config)
hs.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
config
)
# It is now safe to start your Synapse.
hs.start_listening()
hs.get_pusherpool().start()
hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates()
except Exception as e:
# If a DeferredList failed (like in listening on the ACME listener),
# we need to print the subfailure explicitly.
if isinstance(e, defer.FirstError):
e.subFailure.printTraceback(sys.stderr)
sys.exit(1)
# Something else went wrong when starting. Print it and bail out.
traceback.print_exc(file=sys.stderr)
sys.exit(1)
reactor.callWhenRunning(start)
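
The reprovisioning condition above is easy to misread because of the double
negative; written as a plain function it amounts to this sketch (the default
threshold mirrors acme_reprovision_threshold):

def needs_reprovision(cert_days_remaining, reprovision_threshold=30):
    """True if no cert exists, or it expires within the threshold (days)."""
    if cert_days_remaining is None:
        return True  # no certificate on disk yet
    return cert_days_remaining <= reprovision_threshold

assert needs_reprovision(None)
assert needs_reprovision(10)
assert not needs_reprovision(90)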

synapse/config/_base.py
@@ -367,7 +367,7 @@ class Config(object):
if not keys_directory:
keys_directory = os.path.dirname(config_files[-1])
config_dir_path = os.path.abspath(keys_directory)
self.config_dir_path = os.path.abspath(keys_directory)
specified_config = {}
for config_file in config_files:
@@ -379,7 +379,7 @@
server_name = specified_config["server_name"]
config_string = self.generate_config(
config_dir_path=config_dir_path,
config_dir_path=self.config_dir_path,
data_dir_path=os.getcwd(),
server_name=server_name,
generate_secrets=False,

synapse/config/key.py
@@ -83,9 +83,6 @@ class KeyConfig(Config):
# a secret which is used to sign access tokens. If none is specified,
# the registration_shared_secret is used, if one is given; otherwise,
# a secret key is derived from the signing key.
#
# Note that changing this will invalidate any active access tokens, so
# all clients will have to log back in.
%(macaroon_secret_key)s
# Used to enable access token expiration.

synapse/config/registration.py
@@ -50,6 +50,10 @@ class RegistrationConfig(Config):
raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,))
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
self.disable_msisdn_registration = (
config.get("disable_msisdn_registration", False)
)
def default_config(self, generate_secrets=False, **kwargs):
if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
@@ -70,6 +74,11 @@ class RegistrationConfig(Config):
# - email
# - msisdn
# Explicitly disable asking for MSISDNs from the registration
# flow (overrides registrations_require_3pid if MSISDNs are set as required)
#
# disable_msisdn_registration = True
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#

synapse/config/tls.py
@@ -13,52 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import subprocess
from datetime import datetime
from hashlib import sha256
from unpaddedbase64 import encode_base64
from OpenSSL import crypto
from ._base import Config
from synapse.config._base import Config
GENERATE_DH_PARAMS = False
logger = logging.getLogger()
class TlsConfig(Config):
def read_config(self, config):
self.tls_certificate = self.read_tls_certificate(
config.get("tls_certificate_path")
)
self.tls_certificate_file = config.get("tls_certificate_path")
acme_config = config.get("acme", {})
self.acme_enabled = acme_config.get("enabled", False)
self.acme_url = acme_config.get(
"url", "https://acme-v01.api.letsencrypt.org/directory"
)
self.acme_port = acme_config.get("port", 8449)
self.acme_bind_addresses = acme_config.get("bind_addresses", ["127.0.0.1"])
self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
self.tls_certificate_file = os.path.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = os.path.abspath(config.get("tls_private_key_path"))
self._original_tls_fingerprints = config["tls_fingerprints"]
self.tls_fingerprints = list(self._original_tls_fingerprints)
self.no_tls = config.get("no_tls", False)
if self.no_tls:
self.tls_private_key = None
else:
self.tls_private_key = self.read_tls_private_key(
config.get("tls_private_key_path")
)
self.tls_dh_params_path = self.check_file(
config.get("tls_dh_params_path"), "tls_dh_params"
)
self.tls_fingerprints = config["tls_fingerprints"]
# Check that our own certificate is included in the list of fingerprints
# and include it if it is not.
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1,
self.tls_certificate
)
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
if sha256_fingerprint not in sha256_fingerprints:
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
# This config option applies to non-federation HTTP clients
# (e.g. for talking to recaptcha, identity servers, and such)
# It should never be used in production, and is intended for
@@ -67,14 +53,70 @@ class TlsConfig(Config):
"use_insecure_ssl_client_just_for_testing_do_not_use"
)
self.tls_certificate = None
self.tls_private_key = None
def is_disk_cert_valid(self):
"""
Is the certificate we have on disk valid, and if so, for how long?
Returns:
int: Days remaining of certificate validity.
None: No certificate exists.
"""
if not os.path.exists(self.tls_certificate_file):
return None
try:
with open(self.tls_certificate_file, 'rb') as f:
cert_pem = f.read()
except Exception:
logger.exception("Failed to read existing certificate off disk!")
raise
try:
tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
except Exception:
logger.exception("Failed to parse existing certificate off disk!")
raise
# YYYYMMDDhhmmssZ -- in UTC
expires_on = datetime.strptime(
tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ"
)
now = datetime.utcnow()
days_remaining = (expires_on - now).days
return days_remaining
def read_certificate_from_disk(self):
"""
Read the certificates from disk.
"""
self.tls_certificate = self.read_tls_certificate(self.tls_certificate_file)
if not self.no_tls:
self.tls_private_key = self.read_tls_private_key(self.tls_private_key_file)
self.tls_fingerprints = list(self._original_tls_fingerprints)
# Check that our own certificate is included in the list of fingerprints
# and include it if it is not.
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, self.tls_certificate
)
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
if sha256_fingerprint not in sha256_fingerprints:
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
def default_config(self, config_dir_path, server_name, **kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
tls_certificate_path = base_key_name + ".tls.crt"
tls_private_key_path = base_key_name + ".tls.key"
tls_dh_params_path = base_key_name + ".tls.dh"
return """\
return (
"""\
# PEM encoded X509 certificate for TLS.
# You can replace the self-signed certificate that synapse
# autogenerates on launch with your own SSL certificate + key pair
@@ -85,9 +127,6 @@ class TlsConfig(Config):
# PEM encoded private key for TLS
tls_private_key_path: "%(tls_private_key_path)s"
# PEM dh parameters for ephemeral keys
tls_dh_params_path: "%(tls_dh_params_path)s"
# Don't bind to the https port
no_tls: False
@@ -118,7 +157,24 @@ class TlsConfig(Config):
#
tls_fingerprints: []
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
""" % locals()
## Support for ACME certificate auto-provisioning.
# acme:
# enabled: false
## ACME path.
## If you only want to test, use the staging url:
## https://acme-staging.api.letsencrypt.org/directory
# url: 'https://acme-v01.api.letsencrypt.org/directory'
## Port number (to listen for the HTTP-01 challenge).
## Using port 80 requires utilising something like authbind, or proxying to it.
# port: 8449
## Hosts to bind to.
# bind_addresses: ['127.0.0.1']
## How many days remaining on a certificate before it is renewed.
# reprovision_threshold: 30
"""
% locals()
)
def read_tls_certificate(self, cert_path):
cert_pem = self.read_file(cert_path, "tls_certificate")
@@ -131,7 +187,6 @@
def generate_files(self, config):
tls_certificate_path = config["tls_certificate_path"]
tls_private_key_path = config["tls_private_key_path"]
tls_dh_params_path = config["tls_dh_params_path"]
if not self.path_exists(tls_private_key_path):
with open(tls_private_key_path, "wb") as private_key_file:
@@ -165,31 +220,3 @@
cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
certificate_file.write(cert_pem)
if not self.path_exists(tls_dh_params_path):
if GENERATE_DH_PARAMS:
subprocess.check_call([
"openssl", "dhparam",
"-outform", "PEM",
"-out", tls_dh_params_path,
"2048"
])
else:
with open(tls_dh_params_path, "w") as dh_params_file:
dh_params_file.write(
"2048-bit DH parameters taken from rfc3526\n"
"-----BEGIN DH PARAMETERS-----\n"
"MIIBCAKCAQEA///////////JD9qiIWjC"
"NMTGYouA3BzRKQJOCIpnzHQCC76mOxOb\n"
"IlFKCHmONATd75UZs806QxswKwpt8l8U"
"N0/hNW1tUcJF5IW1dmJefsb0TELppjft\n"
"awv/XLb0Brft7jhr+1qJn6WunyQRfEsf"
"5kkoZlHs5Fs9wgB8uKFjvwWY2kg2HFXT\n"
"mmkWP6j9JM9fg2VdI9yjrZYcYvNWIIVS"
"u57VKQdwlpZtZww1Tkq8mATxdGwIyhgh\n"
"fDKQXkYuNs474553LBgOhgObJ4Oi7Aei"
"j7XFXfBvTFLJ3ivL9pVYFxg5lUl86pVq\n"
"5RXSJhiY+gUQFXKOWoqsqmj/////////"
"/wIBAg==\n"
"-----END DH PARAMETERS-----\n"
)

synapse/crypto/context_factory.py
@@ -17,6 +17,7 @@ from zope.interface import implementer
from OpenSSL import SSL, crypto
from twisted.internet._sslverify import _defaultCurveName
from twisted.internet.abstract import isIPAddress, isIPv6Address
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
from twisted.internet.ssl import CertificateOptions, ContextFactory
from twisted.python.failure import Failure
@@ -46,8 +47,10 @@ class ServerContextFactory(ContextFactory):
if not config.no_tls:
context.use_privatekey(config.tls_private_key)
context.load_tmp_dh(config.tls_dh_params_path)
context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH")
# https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
context.set_cipher_list(
"ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1"
)
def getContext(self):
return self._context
@@ -96,8 +99,14 @@ class ClientTLSOptions(object):
def __init__(self, hostname, ctx):
self._ctx = ctx
self._hostname = hostname
self._hostnameBytes = _idnaBytes(hostname)
if isIPAddress(hostname) or isIPv6Address(hostname):
self._hostnameBytes = hostname.encode('ascii')
self._sendSNI = False
else:
self._hostnameBytes = _idnaBytes(hostname)
self._sendSNI = True
ctx.set_info_callback(
_tolerateErrors(self._identityVerifyingInfoCallback)
)
@@ -109,7 +118,9 @@
return connection
def _identityVerifyingInfoCallback(self, connection, where, ret):
if where & SSL.SSL_CB_HANDSHAKE_START:
# Literal IPv4 and IPv6 addresses are not permitted
# as host names according to the RFCs
if where & SSL.SSL_CB_HANDSHAKE_START and self._sendSNI:
connection.set_tlsext_host_name(self._hostnameBytes)
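
The change above (see changelog.d/4452.bugfix) boils down to: RFC 6066 forbids
literal IP addresses in the SNI extension, so SNI is only set when the hostname
is a real DNS name. A standalone sketch of the same decision, assuming Twisted
is installed (encode("idna") stands in for the _idnaBytes helper):

from twisted.internet.abstract import isIPAddress, isIPv6Address

def sni_name_for(hostname):
    """Return the bytes to send as SNI, or None for IP literals."""
    if isIPAddress(hostname) or isIPv6Address(hostname):
        return None  # RFC 6066: literal IPs must not be sent as SNI
    return hostname.encode("idna")

assert sni_name_for("10.1.2.3") is None
assert sni_name_for("2001:db8::1") is None
assert sni_name_for("matrix.org") == b"matrix.org"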

synapse/crypto/keyclient.py (deleted)
@@ -1,149 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from six.moves import urllib
from canonicaljson import json
from twisted.internet import defer, reactor
from twisted.internet.error import ConnectError
from twisted.internet.protocol import Factory
from twisted.names.error import DomainError
from twisted.web.http import HTTPClient
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util import logcontext
logger = logging.getLogger(__name__)
KEY_API_V2 = "/_matrix/key/v2/server/%s"
@defer.inlineCallbacks
def fetch_server_key(server_name, tls_client_options_factory, key_id):
"""Fetch the keys for a remote server."""
factory = SynapseKeyClientFactory()
factory.path = KEY_API_V2 % (urllib.parse.quote(key_id), )
factory.host = server_name
endpoint = matrix_federation_endpoint(
reactor, server_name, tls_client_options_factory, timeout=30
)
for i in range(5):
try:
with logcontext.PreserveLoggingContext():
protocol = yield endpoint.connect(factory)
server_response, server_certificate = yield protocol.remote_key
defer.returnValue((server_response, server_certificate))
except SynapseKeyClientError as e:
logger.warn("Error getting key for %r: %s", server_name, e)
if e.status.startswith(b"4"):
# Don't retry for 4xx responses.
raise IOError("Cannot get key for %r" % server_name)
except (ConnectError, DomainError) as e:
logger.warn("Error getting key for %r: %s", server_name, e)
except Exception:
logger.exception("Error getting key for %r", server_name)
raise IOError("Cannot get key for %r" % server_name)
class SynapseKeyClientError(Exception):
"""The key wasn't retrieved from the remote server."""
status = None
pass
class SynapseKeyClientProtocol(HTTPClient):
"""Low level HTTPS client which retrieves an application/json response from
the server and extracts the X.509 certificate for the remote peer from the
SSL connection."""
timeout = 30
def __init__(self):
self.remote_key = defer.Deferred()
self.host = None
self._peer = None
def connectionMade(self):
self._peer = self.transport.getPeer()
logger.debug("Connected to %s", self._peer)
if not isinstance(self.path, bytes):
self.path = self.path.encode('ascii')
if not isinstance(self.host, bytes):
self.host = self.host.encode('ascii')
self.sendCommand(b"GET", self.path)
if self.host:
self.sendHeader(b"Host", self.host)
self.endHeaders()
self.timer = reactor.callLater(
self.timeout,
self.on_timeout
)
def errback(self, error):
if not self.remote_key.called:
self.remote_key.errback(error)
def callback(self, result):
if not self.remote_key.called:
self.remote_key.callback(result)
def handleStatus(self, version, status, message):
if status != b"200":
# logger.info("Non-200 response from %s: %s %s",
# self.transport.getHost(), status, message)
error = SynapseKeyClientError(
"Non-200 response %r from %r" % (status, self.host)
)
error.status = status
self.errback(error)
self.transport.abortConnection()
def handleResponse(self, response_body_bytes):
try:
json_response = json.loads(response_body_bytes)
except ValueError:
# logger.info("Invalid JSON response from %s",
# self.transport.getHost())
self.transport.abortConnection()
return
certificate = self.transport.getPeerCertificate()
self.callback((json_response, certificate))
self.transport.abortConnection()
self.timer.cancel()
def on_timeout(self):
logger.debug(
"Timeout waiting for response from %s: %s",
self.host, self._peer,
)
self.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection()
class SynapseKeyClientFactory(Factory):
def protocol(self):
protocol = SynapseKeyClientProtocol()
protocol.path = self.path
protocol.host = self.host
return protocol

synapse/crypto/keyring.py
@@ -14,10 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import logging
from collections import namedtuple
from six.moves import urllib
from signedjson.key import (
decode_verify_key_bytes,
encode_verify_key_base64,
@@ -30,13 +31,11 @@ from signedjson.sign import (
signature_ids,
verify_signed_json,
)
from unpaddedbase64 import decode_base64, encode_base64
from unpaddedbase64 import decode_base64
from OpenSSL import crypto
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyclient import fetch_server_key
from synapse.util import logcontext, unwrapFirstError
from synapse.util.logcontext import (
LoggingContext,
@@ -503,31 +502,16 @@ class Keyring(object):
if requested_key_id in keys:
continue
(response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_client_options_factory, requested_key_id
response = yield self.client.get_json(
destination=server_name,
path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id),
ignore_backoff=True,
)
if (u"signatures" not in response
or server_name not in response[u"signatures"]):
raise KeyLookupError("Key response not signed by remote server")
if "tls_fingerprints" not in response:
raise KeyLookupError("Key response missing TLS fingerprints")
certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate
)
sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
response_sha256_fingerprints = set()
for fingerprint in response[u"tls_fingerprints"]:
if u"sha256" in fingerprint:
response_sha256_fingerprints.add(fingerprint[u"sha256"])
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
raise KeyLookupError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
from_server=server_name,
requested_ids=[requested_key_id],
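
The "signed by remote server" check above relies on signedjson. A
self-contained sketch of signing and verifying a JSON object the same way
(the key version "key1" is arbitrary):

from signedjson.key import generate_signing_key, get_verify_key
from signedjson.sign import sign_json, verify_signed_json

signing_key = generate_signing_key("key1")
signed = sign_json({"server_name": "example.org"}, "example.org", signing_key)
assert "example.org" in signed["signatures"]

# Raises signedjson.sign.SignatureVerifyException on a bad signature.
verify_signed_json(signed, "example.org", get_verify_key(signing_key))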

synapse/federation/federation_server.py
@@ -369,13 +369,13 @@ class FederationServer(FederationBase):
})
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
def on_invite_request(self, origin, content, room_version):
pdu = event_from_pdu_json(content)
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, pdu.room_id)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
defer.returnValue({"event": ret_pdu.get_pdu_json(time_now)})
@defer.inlineCallbacks
def on_send_join_request(self, origin, content):

synapse/federation/transport/client.py
@@ -21,7 +21,7 @@ from six.moves import urllib
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.urls import FEDERATION_V1_PREFIX
from synapse.util.logutils import log_function
logger = logging.getLogger(__name__)
@@ -51,7 +51,7 @@ class TransportLayerClient(object):
logger.debug("get_room_state dest=%s, room=%s",
destination, room_id)
path = _create_path(PREFIX, "/state/%s/", room_id)
path = _create_v1_path("/state/%s/", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
)
@@ -73,7 +73,7 @@ class TransportLayerClient(object):
logger.debug("get_room_state_ids dest=%s, room=%s",
destination, room_id)
path = _create_path(PREFIX, "/state_ids/%s/", room_id)
path = _create_v1_path("/state_ids/%s/", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
)
@@ -95,7 +95,7 @@ class TransportLayerClient(object):
logger.debug("get_pdu dest=%s, event_id=%s",
destination, event_id)
path = _create_path(PREFIX, "/event/%s/", event_id)
path = _create_v1_path("/event/%s/", event_id)
return self.client.get_json(destination, path=path, timeout=timeout)
@log_function
@@ -121,7 +121,7 @@ class TransportLayerClient(object):
# TODO: raise?
return
path = _create_path(PREFIX, "/backfill/%s/", room_id)
path = _create_v1_path("/backfill/%s/", room_id)
args = {
"v": event_tuples,
@@ -167,7 +167,7 @@ class TransportLayerClient(object):
# generated by the json_data_callback.
json_data = transaction.get_dict()
path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id)
path = _create_v1_path("/send/%s/", transaction.transaction_id)
response = yield self.client.put_json(
transaction.destination,
@@ -184,7 +184,7 @@ class TransportLayerClient(object):
@log_function
def make_query(self, destination, query_type, args, retry_on_dns_fail,
ignore_backoff=False):
path = _create_path(PREFIX, "/query/%s", query_type)
path = _create_v1_path("/query/%s", query_type)
content = yield self.client.get_json(
destination=destination,
@@ -231,7 +231,7 @@ class TransportLayerClient(object):
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id)
path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id)
ignore_backoff = False
retry_on_dns_fail = False
@@ -258,7 +258,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_join(self, destination, room_id, event_id, content):
path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id)
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@@ -271,7 +271,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_leave(self, destination, room_id, event_id, content):
path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id)
path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@@ -290,7 +290,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_invite(self, destination, room_id, event_id, content):
path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id)
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@@ -306,7 +306,7 @@ class TransportLayerClient(object):
def get_public_rooms(self, remote_server, limit, since_token,
search_filter=None, include_all_networks=False,
third_party_instance_id=None):
path = PREFIX + "/publicRooms"
path = _create_v1_path("/publicRooms")
args = {
"include_all_networks": "true" if include_all_networks else "false",
@@ -332,7 +332,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):
path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,)
path = _create_v1_path("/exchange_third_party_invite/%s", room_id,)
response = yield self.client.put_json(
destination=destination,
@@ -345,7 +345,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id)
path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
content = yield self.client.get_json(
destination=destination,
@@ -357,7 +357,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_query_auth(self, destination, room_id, event_id, content):
path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id)
path = _create_v1_path("/query_auth/%s/%s", room_id, event_id)
content = yield self.client.post_json(
destination=destination,
@@ -392,7 +392,7 @@ class TransportLayerClient(object):
Returns:
A dict containing the device keys.
"""
path = PREFIX + "/user/keys/query"
path = _create_v1_path("/user/keys/query")
content = yield self.client.post_json(
destination=destination,
@@ -419,7 +419,7 @@ class TransportLayerClient(object):
Returns:
A dict containing the device keys.
"""
path = _create_path(PREFIX, "/user/devices/%s", user_id)
path = _create_v1_path("/user/devices/%s", user_id)
content = yield self.client.get_json(
destination=destination,
@@ -455,7 +455,7 @@ class TransportLayerClient(object):
A dict containing the one-time keys.
"""
path = PREFIX + "/user/keys/claim"
path = _create_v1_path("/user/keys/claim")
content = yield self.client.post_json(
destination=destination,
@@ -469,7 +469,7 @@ class TransportLayerClient(object):
@log_function
def get_missing_events(self, destination, room_id, earliest_events,
latest_events, limit, min_depth, timeout):
path = _create_path(PREFIX, "/get_missing_events/%s", room_id,)
path = _create_v1_path("/get_missing_events/%s", room_id,)
content = yield self.client.post_json(
destination=destination,
@@ -489,7 +489,7 @@ class TransportLayerClient(object):
def get_group_profile(self, destination, group_id, requester_user_id):
"""Get a group profile
"""
path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
path = _create_v1_path("/groups/%s/profile", group_id,)
return self.client.get_json(
destination=destination,
@@ -508,7 +508,7 @@ class TransportLayerClient(object):
requester_user_id (str)
content (dict): The new profile of the group
"""
path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
path = _create_v1_path("/groups/%s/profile", group_id,)
return self.client.post_json(
destination=destination,
@@ -522,7 +522,7 @@ class TransportLayerClient(object):
def get_group_summary(self, destination, group_id, requester_user_id):
"""Get a group summary
"""
path = _create_path(PREFIX, "/groups/%s/summary", group_id,)
path = _create_v1_path("/groups/%s/summary", group_id,)
return self.client.get_json(
destination=destination,
@@ -535,7 +535,7 @@ class TransportLayerClient(object):
def get_rooms_in_group(self, destination, group_id, requester_user_id):
"""Get all rooms in a group
"""
path = _create_path(PREFIX, "/groups/%s/rooms", group_id,)
path = _create_v1_path("/groups/%s/rooms", group_id,)
return self.client.get_json(
destination=destination,
@@ -548,7 +548,7 @@ class TransportLayerClient(object):
content):
"""Add a room to a group
"""
path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
return self.client.post_json(
destination=destination,
@@ -562,8 +562,8 @@ class TransportLayerClient(object):
config_key, content):
"""Update room in group
"""
path = _create_path(
PREFIX, "/groups/%s/room/%s/config/%s",
path = _create_v1_path(
"/groups/%s/room/%s/config/%s",
group_id, room_id, config_key,
)
@@ -578,7 +578,7 @@ class TransportLayerClient(object):
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
"""Remove a room from a group
"""
path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
return self.client.delete_json(
destination=destination,
@@ -591,7 +591,7 @@ class TransportLayerClient(object):
def get_users_in_group(self, destination, group_id, requester_user_id):
"""Get users in a group
"""
path = _create_path(PREFIX, "/groups/%s/users", group_id,)
path = _create_v1_path("/groups/%s/users", group_id,)
return self.client.get_json(
destination=destination,
@@ -604,7 +604,7 @@ class TransportLayerClient(object):
def get_invited_users_in_group(self, destination, group_id, requester_user_id):
"""Get users that have been invited to a group
"""
path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,)
path = _create_v1_path("/groups/%s/invited_users", group_id,)
return self.client.get_json(
destination=destination,
@@ -617,8 +617,8 @@ class TransportLayerClient(object):
def accept_group_invite(self, destination, group_id, user_id, content):
"""Accept a group invite
"""
path = _create_path(
PREFIX, "/groups/%s/users/%s/accept_invite",
path = _create_v1_path(
"/groups/%s/users/%s/accept_invite",
group_id, user_id,
)
@@ -633,7 +633,7 @@ class TransportLayerClient(object):
def join_group(self, destination, group_id, user_id, content):
"""Attempts to join a group
"""
path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id)
path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -646,7 +646,7 @@ class TransportLayerClient(object):
def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
"""Invite a user to a group
"""
path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id)
path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -662,7 +662,7 @@ class TransportLayerClient(object):
invited.
"""
path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id)
path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -676,7 +676,7 @@ class TransportLayerClient(object):
user_id, content):
"""Remove a user fron a group
"""
path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id)
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -693,7 +693,7 @@ class TransportLayerClient(object):
kicked from the group.
"""
path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id)
path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -708,7 +708,7 @@ class TransportLayerClient(object):
the attestations
"""
path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id)
path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -723,12 +723,12 @@ class TransportLayerClient(object):
"""Update a room entry in a group summary
"""
if category_id:
path = _create_path(
PREFIX, "/groups/%s/summary/categories/%s/rooms/%s",
path = _create_v1_path(
"/groups/%s/summary/categories/%s/rooms/%s",
group_id, category_id, room_id,
)
else:
path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
return self.client.post_json(
destination=destination,
@@ -744,12 +744,12 @@ class TransportLayerClient(object):
"""Delete a room entry in a group summary
"""
if category_id:
path = _create_path(
PREFIX + "/groups/%s/summary/categories/%s/rooms/%s",
path = _create_v1_path(
"/groups/%s/summary/categories/%s/rooms/%s",
group_id, category_id, room_id,
)
else:
path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
return self.client.delete_json(
destination=destination,
@@ -762,7 +762,7 @@ class TransportLayerClient(object):
def get_group_categories(self, destination, group_id, requester_user_id):
"""Get all categories in a group
"""
path = _create_path(PREFIX, "/groups/%s/categories", group_id,)
path = _create_v1_path("/groups/%s/categories", group_id,)
return self.client.get_json(
destination=destination,
@@ -775,7 +775,7 @@ class TransportLayerClient(object):
def get_group_category(self, destination, group_id, requester_user_id, category_id):
"""Get category info in a group
"""
path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
return self.client.get_json(
destination=destination,
@@ -789,7 +789,7 @@ class TransportLayerClient(object):
content):
"""Update a category in a group
"""
path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
return self.client.post_json(
destination=destination,
@@ -804,7 +804,7 @@ class TransportLayerClient(object):
category_id):
"""Delete a category in a group
"""
path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
return self.client.delete_json(
destination=destination,
@@ -817,7 +817,7 @@ class TransportLayerClient(object):
def get_group_roles(self, destination, group_id, requester_user_id):
"""Get all roles in a group
"""
path = _create_path(PREFIX, "/groups/%s/roles", group_id,)
path = _create_v1_path("/groups/%s/roles", group_id,)
return self.client.get_json(
destination=destination,
@@ -830,7 +830,7 @@ class TransportLayerClient(object):
def get_group_role(self, destination, group_id, requester_user_id, role_id):
"""Get a roles info
"""
path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
return self.client.get_json(
destination=destination,
@@ -844,7 +844,7 @@ class TransportLayerClient(object):
content):
"""Update a role in a group
"""
path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
return self.client.post_json(
destination=destination,
@@ -858,7 +858,7 @@ class TransportLayerClient(object):
def delete_group_role(self, destination, group_id, requester_user_id, role_id):
"""Delete a role in a group
"""
path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
return self.client.delete_json(
destination=destination,
@@ -873,12 +873,12 @@ class TransportLayerClient(object):
"""Update a users entry in a group
"""
if role_id:
path = _create_path(
PREFIX, "/groups/%s/summary/roles/%s/users/%s",
path = _create_v1_path(
"/groups/%s/summary/roles/%s/users/%s",
group_id, role_id, user_id,
)
else:
path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
return self.client.post_json(
destination=destination,
@@ -893,7 +893,7 @@ class TransportLayerClient(object):
content):
"""Sets the join policy for a group
"""
path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,)
path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id,)
return self.client.put_json(
destination=destination,
@@ -909,12 +909,12 @@ class TransportLayerClient(object):
"""Delete a users entry in a group
"""
if role_id:
path = _create_path(
PREFIX, "/groups/%s/summary/roles/%s/users/%s",
path = _create_v1_path(
"/groups/%s/summary/roles/%s/users/%s",
group_id, role_id, user_id,
)
else:
path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
return self.client.delete_json(
destination=destination,
@@ -927,7 +927,7 @@ class TransportLayerClient(object):
"""Get the groups a list of users are publicising
"""
path = PREFIX + "/get_groups_publicised"
path = _create_v1_path("/get_groups_publicised")
content = {"user_ids": user_ids}
@@ -939,20 +939,22 @@
)
def _create_path(prefix, path, *args):
"""Creates a path from the prefix, path template and args. Ensures that
all args are url encoded.
def _create_v1_path(path, *args):
"""Creates a path against V1 federation API from the path template and
args. Ensures that all args are url encoded.
Example:
_create_path(PREFIX, "/event/%s/", event_id)
_create_v1_path("/event/%s/", event_id)
Args:
prefix (str)
path (str): String template for the path
args: ([str]): Args to insert into path. Each arg will be url encoded
Returns:
str
"""
return prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args)
return (
FEDERATION_V1_PREFIX
+ path % tuple(urllib.parse.quote(arg, "") for arg in args)
)
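
To make the URL-encoding guarantee concrete, here is a runnable standalone
copy of the helper and what it produces for an event ID containing reserved
characters:

from six.moves import urllib

FEDERATION_V1_PREFIX = "/_matrix/federation/v1"

def _create_v1_path(path, *args):
    return FEDERATION_V1_PREFIX + path % tuple(
        urllib.parse.quote(arg, "") for arg in args
    )

print(_create_v1_path("/event/%s/", "$ev1:example.org"))
# -> /_matrix/federation/v1/event/%24ev1%3Aexample.org/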

synapse/federation/transport/server.py
@@ -21,8 +21,9 @@ import re
from twisted.internet import defer
import synapse
from synapse.api.constants import RoomVersions
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
@@ -227,6 +228,8 @@
"""
REQUIRE_AUTH = True
PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version
def __init__(self, handler, authenticator, ratelimiter, server_name):
self.handler = handler
self.authenticator = authenticator
@@ -286,7 +289,7 @@ class BaseFederationServlet(object):
return new_func
def register(self, server):
pattern = re.compile("^" + PREFIX + self.PATH + "$")
pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
for method in ("GET", "PUT", "POST"):
code = getattr(self, "on_%s" % (method), None)
@@ -488,14 +491,46 @@ class FederationSendJoinServlet(BaseFederationServlet):
defer.returnValue((200, content))
class FederationInviteServlet(BaseFederationServlet):
class FederationV1InviteServlet(BaseFederationServlet):
PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, context, event_id):
# We don't get a room version, so we have to assume it's EITHER v1 or
# v2. This is "fine" as the only difference between V1 and V2 is the
# state resolution algorithm, and we don't use that for processing
# invites
content = yield self.handler.on_invite_request(
origin, content, room_version=RoomVersions.V1,
)
# V1 federation API is defined to return a content of `[200, {...}]`
# due to a historical bug.
defer.returnValue((200, (200, content)))
class FederationV2InviteServlet(BaseFederationServlet):
PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
content = yield self.handler.on_invite_request(origin, content)
room_version = content["room_version"]
event = content["event"]
invite_room_state = content["invite_room_state"]
# Synapse expects invite_room_state to be in unsigned, as it is in v1
# API
event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state
content = yield self.handler.on_invite_request(
origin, event, room_version=room_version,
)
defer.returnValue((200, content))
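
For orientation, the v2 /invite request body that on_PUT consumes (per
MSC1794) has the shape sketched below; the concrete values are made up for
illustration:

v2_invite_content = {
    "room_version": "2",
    "event": {
        "type": "m.room.member",
        "sender": "@alice:remote.example",
        "state_key": "@bob:local.example",
        "content": {"membership": "invite"},
    },
    # Stripped-down state events (e.g. m.room.name, m.room.join_rules):
    "invite_room_state": [],
}

# The servlet folds invite_room_state into the event's unsigned section,
# which is where the v1 code path expects to find it:
event = v2_invite_content["event"]
event.setdefault("unsigned", {})["invite_room_state"] = (
    v2_invite_content["invite_room_state"]
)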
@@ -1263,7 +1298,8 @@ FEDERATION_SERVLET_CLASSES = (
FederationEventServlet,
FederationSendJoinServlet,
FederationSendLeaveServlet,
FederationInviteServlet,
FederationV1InviteServlet,
FederationV2InviteServlet,
FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
FederationEventAuthServlet,

synapse/handlers/acme.py (new file)
@@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import attr
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import serverFromString
from twisted.python.filepath import FilePath
from twisted.python.url import URL
from twisted.web import server, static
from twisted.web.resource import Resource
logger = logging.getLogger(__name__)
try:
from txacme.interfaces import ICertificateStore
@attr.s
@implementer(ICertificateStore)
class ErsatzStore(object):
"""
A store that only stores in memory.
"""
certs = attr.ib(default=attr.Factory(dict))
def store(self, server_name, pem_objects):
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
return defer.succeed(None)
except ImportError:
# txacme is missing
pass
class AcmeHandler(object):
def __init__(self, hs):
self.hs = hs
self.reactor = hs.get_reactor()
@defer.inlineCallbacks
def start_listening(self):
# Configure logging for txacme, if you need to debug
# from eliot import add_destinations
# from eliot.twisted import TwistedDestination
#
# add_destinations(TwistedDestination())
from txacme.challenges import HTTP01Responder
from txacme.service import AcmeIssuingService
from txacme.endpoint import load_or_create_client_key
from txacme.client import Client
from josepy.jwa import RS256
self._store = ErsatzStore()
responder = HTTP01Responder()
self._issuer = AcmeIssuingService(
cert_store=self._store,
client_creator=(
lambda: Client.from_url(
reactor=self.reactor,
url=URL.from_text(self.hs.config.acme_url),
key=load_or_create_client_key(
FilePath(self.hs.config.config_dir_path)
),
alg=RS256,
)
),
clock=self.reactor,
responders=[responder],
)
well_known = Resource()
well_known.putChild(b'acme-challenge', responder.resource)
responder_resource = Resource()
responder_resource.putChild(b'.well-known', well_known)
responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain'))
srv = server.Site(responder_resource)
listeners = []
for host in self.hs.config.acme_bind_addresses:
logger.info(
"Listening for ACME requests on %s:%s", host, self.hs.config.acme_port
)
endpoint = serverFromString(
self.reactor, "tcp:%s:interface=%s" % (self.hs.config.acme_port, host)
)
listeners.append(endpoint.listen(srv))
# Make sure we are registered to the ACME server. There's no public API
# for this, it is usually triggered by startService, but since we don't
# want it to control where we save the certificates, we have to reach in
# and trigger the registration machinery ourselves.
self._issuer._registered = False
yield self._issuer._ensure_registered()
# Return a Deferred that will fire when all the servers have started up.
yield defer.DeferredList(listeners, fireOnOneErrback=True, consumeErrors=True)
@defer.inlineCallbacks
def provision_certificate(self):
logger.warning("Reprovisioning %s", self.hs.hostname)
try:
yield self._issuer.issue_cert(self.hs.hostname)
except Exception:
logger.exception("Fail!")
raise
logger.warning("Reprovisioned %s, saving.", self.hs.hostname)
cert_chain = self._store.certs[self.hs.hostname]
try:
with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:
for x in cert_chain:
if x.startswith(b"-----BEGIN RSA PRIVATE KEY-----"):
private_key_file.write(x)
with open(self.hs.config.tls_certificate_file, "wb") as certificate_file:
for x in cert_chain:
if x.startswith(b"-----BEGIN CERTIFICATE-----"):
certificate_file.write(x)
except Exception:
logger.exception("Failed saving!")
raise
defer.returnValue(True)
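As a standalone illustration of the PEM handling in provision_certificate above (file names and PEM contents here are invented), each block in the stored chain is routed to the key file or the certificate file by its PEM header:

    pem_blocks = [
        b"-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----\n",
        b"-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----\n",
    ]
    with open("example.tls.key", "wb") as key_file:
        with open("example.tls.crt", "wb") as cert_file:
            for block in pem_blocks:
                if block.startswith(b"-----BEGIN RSA PRIVATE KEY-----"):
                    key_file.write(block)
                elif block.startswith(b"-----BEGIN CERTIFICATE-----"):
                    cert_file.write(block)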

View file

@ -167,18 +167,21 @@ class IdentityHandler(BaseHandler):
"mxid": mxid,
"threepid": threepid,
}
headers = {}
# we abuse the federation http client to sign the request, but we have to send it
# using the normal http client since we don't want the SRV lookup and want normal
# 'browser-like' HTTPS.
self.federation_http_client.sign_request(
auth_headers = self.federation_http_client.build_auth_headers(
destination=None,
method='POST',
url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'),
headers_dict=headers,
content=content,
destination_is=id_server,
)
headers = {
b"Authorization": auth_headers,
}
try:
yield self.http_client.post_json_get_json(
url,

View file

@ -269,6 +269,7 @@ class RoomCreationHandler(BaseHandler):
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.GuestAccess, ""),
(EventTypes.RoomAvatar, ""),
(EventTypes.Encryption, ""),
)
old_room_state_ids = yield self.store.get_filtered_current_state_ids(

View file

@ -15,6 +15,7 @@
import logging
from collections import namedtuple
from datetime import datetime, timedelta
from six import PY3, iteritems
from six.moves import range
@ -76,8 +77,14 @@ class RoomListHandler(BaseHandler):
# We explicitly don't bother caching searches or requests for
# appservice specific lists.
logger.info("Bypassing cache as search request.")
# XXX: Quick hack to stop room directory queries taking too long.
# Timeout request after 60s. Probably want a more fundamental
# solution at some point
timeout = datetime.now() + timedelta(seconds=60)
return self._get_public_room_list(
limit, since_token, search_filter, network_tuple=network_tuple,
limit, since_token, search_filter,
network_tuple=network_tuple, timeout=timeout,
)
key = (limit, since_token, network_tuple)
@ -90,7 +97,8 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None,
search_filter=None,
network_tuple=EMPTY_THIRD_PARTY_ID,):
network_tuple=EMPTY_THIRD_PARTY_ID,
timeout=None,):
if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token)
else:
@ -205,6 +213,9 @@ class RoomListHandler(BaseHandler):
chunk = []
for i in range(0, len(rooms_to_scan), step):
if timeout and datetime.now() > timeout:
raise Exception("Timed out searching room directory")
batch = rooms_to_scan[i:i + step]
logger.info("Processing %i rooms for result", len(batch))
yield concurrently_execute(

View file

@ -20,6 +20,7 @@ from six import iteritems
from twisted.internet import defer
import synapse.metrics
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo

View file

@ -12,30 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
import random
import re
import time
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError
logger = logging.getLogger(__name__)
SERVER_CACHE = {}
# our record of an individual server which can be tried to reach a destination.
#
# "host" is the hostname acquired from the SRV record. Except when there's
# no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple(
"_Server", "priority weight host port expires"
)
def parse_server_name(server_name):
"""Split a server name into host/port parts.
@ -100,264 +81,3 @@ def parse_and_validate_server_name(server_name):
))
return host, port
def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
timeout=None):
"""Construct an endpoint for the given matrix destination.
Args:
reactor: Twisted reactor.
destination (unicode): The name of the server to connect to.
tls_client_options_factory
(synapse.crypto.context_factory.ClientTLSOptionsFactory):
Factory which generates TLS options for client connections.
timeout (int): connection timeout in seconds
"""
domain, port = parse_server_name(destination)
endpoint_kw_args = {}
if timeout is not None:
endpoint_kw_args.update(timeout=timeout)
if tls_client_options_factory is None:
transport_endpoint = HostnameEndpoint
default_port = 8008
else:
# the SNI string should be the same as the Host header, minus the port.
# as per https://github.com/matrix-org/synapse/issues/2525#issuecomment-336896777,
# the Host header and SNI should therefore be the server_name of the remote
# server.
tls_options = tls_client_options_factory.get_options(domain)
def transport_endpoint(reactor, host, port, timeout):
return wrapClientTLS(
tls_options,
HostnameEndpoint(reactor, host, port, timeout=timeout),
)
default_port = 8448
if port is None:
return _WrappingEndpointFac(SRVClientEndpoint(
reactor, "matrix", domain, protocol="tcp",
default_port=default_port, endpoint=transport_endpoint,
endpoint_kw_args=endpoint_kw_args
), reactor)
else:
return _WrappingEndpointFac(transport_endpoint(
reactor, domain, port, **endpoint_kw_args
), reactor)
class _WrappingEndpointFac(object):
def __init__(self, endpoint_fac, reactor):
self.endpoint_fac = endpoint_fac
self.reactor = reactor
@defer.inlineCallbacks
def connect(self, protocolFactory):
conn = yield self.endpoint_fac.connect(protocolFactory)
conn = _WrappedConnection(conn, self.reactor)
defer.returnValue(conn)
class _WrappedConnection(object):
"""Wraps a connection and calls abort on it if it hasn't seen any action
for 2.5-3 minutes.
"""
__slots__ = ["conn", "last_request"]
def __init__(self, conn, reactor):
object.__setattr__(self, "conn", conn)
object.__setattr__(self, "last_request", time.time())
self._reactor = reactor
def __getattr__(self, name):
return getattr(self.conn, name)
def __setattr__(self, name, value):
setattr(self.conn, name, value)
def _time_things_out_maybe(self):
# We use a slightly shorter timeout here just in case the callLater is
# triggered early. Paranoia ftw.
# TODO: Cancel the previous callLater rather than comparing time.time()?
if time.time() - self.last_request >= 2.5 * 60:
self.abort()
# Abort the underlying TLS connection. The abort() method calls
# loseConnection() on the TLS connection which tries to
# shutdown the connection cleanly. We call abortConnection()
# since that will promptly close the TLS connection.
#
# In Twisted >18.4; the TLS connection will be None if it has closed
# which will make abortConnection() throw. Check that the TLS connection
# is not None before trying to close it.
if self.transport.getHandle() is not None:
self.transport.abortConnection()
def request(self, request):
self.last_request = time.time()
# Time this connection out if we haven't sent a request in the last
# N minutes
# TODO: Cancel the previous callLater?
self._reactor.callLater(3 * 60, self._time_things_out_maybe)
d = self.conn.request(request)
def update_request_time(res):
self.last_request = time.time()
# TODO: Cancel the previous callLater?
self._reactor.callLater(3 * 60, self._time_things_out_maybe)
return res
d.addCallback(update_request_time)
return d
class SRVClientEndpoint(object):
"""An endpoint which looks up SRV records for a service.
Cycles through the list of servers starting with each call to connect
picking the next server.
Implements twisted.internet.interfaces.IStreamClientEndpoint.
"""
def __init__(self, reactor, service, domain, protocol="tcp",
default_port=None, endpoint=HostnameEndpoint,
endpoint_kw_args={}):
self.reactor = reactor
self.service_name = "_%s._%s.%s" % (service, protocol, domain)
if default_port is not None:
self.default_server = _Server(
host=domain,
port=default_port,
priority=0,
weight=0,
expires=0,
)
else:
self.default_server = None
self.endpoint = endpoint
self.endpoint_kw_args = endpoint_kw_args
self.servers = None
self.used_servers = None
@defer.inlineCallbacks
def fetch_servers(self):
self.used_servers = []
self.servers = yield resolve_service(self.service_name)
def pick_server(self):
if not self.servers:
if self.used_servers:
self.servers = self.used_servers
self.used_servers = []
self.servers.sort()
elif self.default_server:
return self.default_server
else:
raise ConnectError(
"No server available for %s" % self.service_name
)
# look for all servers with the same priority
min_priority = self.servers[0].priority
weight_indexes = list(
(index, server.weight + 1)
for index, server in enumerate(self.servers)
if server.priority == min_priority
)
total_weight = sum(weight for index, weight in weight_indexes)
target_weight = random.randint(0, total_weight)
for index, weight in weight_indexes:
target_weight -= weight
if target_weight <= 0:
server = self.servers[index]
# XXX: this looks totally dubious:
#
# (a) we never reuse a server until we have been through
# all of the servers at the same priority, so if the
# weights are A: 100, B:1, we always do ABABAB instead of
# AAAA...AAAB (approximately).
#
# (b) After using all the servers at the lowest priority,
# we move onto the next priority. We should only use the
# second priority if servers at the top priority are
# unreachable.
#
del self.servers[index]
self.used_servers.append(server)
return server
@defer.inlineCallbacks
def connect(self, protocolFactory):
if self.servers is None:
yield self.fetch_servers()
server = self.pick_server()
logger.info("Connecting to %s:%s", server.host, server.port)
endpoint = self.endpoint(
self.reactor, server.host, server.port, **self.endpoint_kw_args
)
connection = yield endpoint.connect(protocolFactory)
defer.returnValue(connection)
@defer.inlineCallbacks
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
cache_entry = cache.get(service_name, None)
if cache_entry:
if all(s.expires > int(clock.time()) for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
servers = []
try:
try:
answers, _, _ = yield dns_client.lookupService(service_name)
except DNSNameError:
defer.returnValue([])
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
payload = answer.payload
servers.append(_Server(
host=str(payload.target),
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight),
expires=int(clock.time()) + answer.ttl,
))
servers.sort()
cache[service_name] = list(servers)
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else re-raise
cache_entry = cache.get(service_name, None)
if cache_entry:
logger.warn(
"Failed to resolve %r, falling back to cache. %r",
service_name, e
)
servers = list(cache_entry)
else:
raise e
defer.returnValue(servers)

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,124 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.iweb import IAgent
from synapse.http.endpoint import parse_server_name
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__)
@implementer(IAgent)
class MatrixFederationAgent(object):
"""An Agent-like thing which provides a `request` method which will look up a matrix
server and send an HTTP request to it.
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
Args:
reactor (IReactor): twisted reactor to use for underlying requests
tls_client_options_factory (ClientTLSOptionsFactory|None):
factory to use for fetching client tls options, or none to disable TLS.
srv_resolver (SrvResolver|None):
SRVResolver impl to use for looking up SRV records. None to use a default
implementation.
"""
def __init__(
self, reactor, tls_client_options_factory, _srv_resolver=None,
):
self._reactor = reactor
self._tls_client_options_factory = tls_client_options_factory
if _srv_resolver is None:
_srv_resolver = SrvResolver()
self._srv_resolver = _srv_resolver
self._pool = HTTPConnectionPool(reactor)
self._pool.retryAutomatically = False
self._pool.maxPersistentPerHost = 5
self._pool.cachedConnectionTimeout = 2 * 60
@defer.inlineCallbacks
def request(self, method, uri, headers=None, bodyProducer=None):
"""
Args:
method (bytes): HTTP method: GET/POST/etc
uri (bytes): Absolute URI to be retrieved
headers (twisted.web.http_headers.Headers|None):
HTTP headers to send with the request, or None to
send no extra headers.
bodyProducer (twisted.web.iweb.IBodyProducer|None):
An object which can generate bytes to make up the
body of this request (for example, the properly encoded contents of
a file for a file upload). Or None if the request is to have
no body.
Returns:
Deferred[twisted.web.iweb.IResponse]:
fires when the header of the response has been received (regardless of the
response status code). Fails if there is any problem which prevents that
response from being received (including problems that prevent the request
from being sent).
"""
parsed_uri = URI.fromBytes(uri)
server_name_bytes = parsed_uri.netloc
host, port = parse_server_name(server_name_bytes.decode("ascii"))
# XXX disabling TLS is really only supported here for the benefit of the
# unit tests. We should make the UTs cope with TLS rather than having to make
# the code support the unit tests.
if self._tls_client_options_factory is None:
tls_options = None
else:
tls_options = self._tls_client_options_factory.get_options(host)
if port is not None:
target = (host, port)
else:
server_list = yield self._srv_resolver.resolve_service(server_name_bytes)
if not server_list:
target = (host, 8448)
logger.debug("No SRV record for %s, using %s", host, target)
else:
target = pick_server_from_list(server_list)
class EndpointFactory(object):
@staticmethod
def endpointForURI(_uri):
logger.info("Connecting to %s:%s", target[0], target[1])
ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1])
if tls_options is not None:
ep = wrapClientTLS(tls_options, ep)
return ep
agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
res = yield make_deferred_yieldable(
agent.request(method, uri, headers, bodyProducer)
)
defer.returnValue(res)
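A hedged usage sketch of the agent, mirroring how MatrixFederationHttpClient constructs it later in this diff (the hostname is invented; retries, timeouts and error handling are deliberately omitted):

    from twisted.internet import defer, reactor

    from synapse.crypto.context_factory import ClientTLSOptionsFactory
    from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent

    @defer.inlineCallbacks
    def fetch_federation_version():
        agent = MatrixFederationAgent(reactor, ClientTLSOptionsFactory(None))
        # matrix:// URIs carry the server name; SRV resolution happens inside.
        response = yield agent.request(
            b"GET", b"matrix://example.org/_matrix/federation/v1/version",
        )
        defer.returnValue(response.code)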

View file

@ -0,0 +1,169 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import time
import attr
from twisted.internet import defer
from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError
from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__)
SERVER_CACHE = {}
@attr.s
class Server(object):
"""
Our record of an individual server which can be tried to reach a destination.
Attributes:
host (bytes): target hostname
port (int):
priority (int):
weight (int):
expires (int): when the cache should expire this record - in *seconds* since
the epoch
"""
host = attr.ib()
port = attr.ib()
priority = attr.ib(default=0)
weight = attr.ib(default=0)
expires = attr.ib(default=0)
def pick_server_from_list(server_list):
"""Randomly choose a server from the server list
Args:
server_list (list[Server]): list of candidate servers
Returns:
Tuple[bytes, int]: (host, port) pair for the chosen server
"""
if not server_list:
raise RuntimeError("pick_server_from_list called with empty list")
# TODO: currently we only use the lowest-priority servers. We should maintain a
# cache of servers known to be "down" and filter them out
min_priority = min(s.priority for s in server_list)
eligible_servers = list(s for s in server_list if s.priority == min_priority)
total_weight = sum(s.weight for s in eligible_servers)
target_weight = random.randint(0, total_weight)
for s in eligible_servers:
target_weight -= s.weight
if target_weight <= 0:
return s.host, s.port
# this should be impossible.
raise RuntimeError(
"pick_server_from_list got to end of eligible server list.",
)
class SrvResolver(object):
"""Interface to the dns client to do SRV lookups, with result caching.
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
but the cache never gets populated), so we add our own caching layer here.
Args:
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
cache (dict): cache object
get_time (callable): clock implementation. Should return seconds since the epoch
"""
def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
self._dns_client = dns_client
self._cache = cache
self._get_time = get_time
@defer.inlineCallbacks
def resolve_service(self, service_name):
"""Look up a SRV record
Args:
service_name (bytes): record to look up
Returns:
Deferred[list[Server]]:
a list of the SRV records, or an empty list if none found
"""
now = int(self._get_time())
if not isinstance(service_name, bytes):
raise TypeError("%r is not a byte string" % (service_name,))
cache_entry = self._cache.get(service_name, None)
if cache_entry:
if all(s.expires > now for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
try:
answers, _, _ = yield make_deferred_yieldable(
self._dns_client.lookupService(service_name),
)
except DNSNameError:
# TODO: cache this. We can get the SOA out of the exception, and use
# the negative-TTL value.
defer.returnValue([])
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else re-raise
cache_entry = self._cache.get(service_name, None)
if cache_entry:
logger.warn(
"Failed to resolve %r, falling back to cache. %r",
service_name, e
)
defer.returnValue(list(cache_entry))
else:
raise e
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
servers = []
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
payload = answer.payload
servers.append(Server(
host=payload.target.name,
port=payload.port,
priority=payload.priority,
weight=payload.weight,
expires=now + answer.ttl,
))
self._cache[service_name] = list(servers)
defer.returnValue(servers)
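To make the weighted selection concrete, an illustrative call to the helpers above (hostnames, ports and weights invented): both priority-0 servers are eligible and are picked in roughly a 3:1 ratio, while the priority-10 server is only considered if no lower-priority server is in the list.

    servers = [
        Server(host=b"a.example.com", port=8448, priority=0, weight=3),
        Server(host=b"b.example.com", port=8448, priority=0, weight=1),
        Server(host=b"c.example.com", port=8448, priority=10, weight=100),
    ]
    host, port = pick_server_from_list(servers)
    assert host in (b"a.example.com", b"b.example.com")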

View file

@ -32,7 +32,7 @@ from twisted.internet import defer, protocol
from twisted.internet.error import DNSLookupError
from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool
from twisted.web.client import FileBodyProducer
from twisted.web.http_headers import Headers
import synapse.metrics
@ -44,7 +44,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.util.async_helpers import timeout_deferred
from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.metrics import Measure
@ -66,20 +66,6 @@ else:
MAXINT = sys.maxint
class MatrixFederationEndpointFactory(object):
def __init__(self, hs):
self.reactor = hs.get_reactor()
self.tls_client_options_factory = hs.tls_client_options_factory
def endpointForURI(self, uri):
destination = uri.netloc.decode('ascii')
return matrix_federation_endpoint(
self.reactor, destination, timeout=10,
tls_client_options_factory=self.tls_client_options_factory
)
_next_id = 1
@ -187,12 +173,10 @@ class MatrixFederationHttpClient(object):
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
reactor = hs.get_reactor()
pool = HTTPConnectionPool(reactor)
pool.retryAutomatically = False
pool.maxPersistentPerHost = 5
pool.cachedConnectionTimeout = 2 * 60
self.agent = Agent.usingEndpointFactory(
reactor, MatrixFederationEndpointFactory(hs), pool=pool
self.agent = MatrixFederationAgent(
hs.get_reactor(),
hs.tls_client_options_factory,
)
self.clock = hs.get_clock()
self._store = hs.get_datastore()
@ -229,19 +213,18 @@ class MatrixFederationHttpClient(object):
backoff_on_404 (bool): Back off if we get a 404
Returns:
Deferred: resolves with the http response object on success.
Deferred[twisted.web.client.Response]: resolves with the HTTP
response object on success.
Fails with ``HttpResponseException``: if we get an HTTP response
code >= 300 (except 429).
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
Fails with ``RequestSendFailed`` if there were problems connecting to
the remote, due to e.g. DNS failures, connection timeouts etc.
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
if timeout:
_sec_timeout = timeout / 1000
@ -299,9 +282,9 @@ class MatrixFederationHttpClient(object):
json = request.get_json()
if json:
headers_dict[b"Content-Type"] = [b"application/json"]
self.sign_request(
auth_headers = self.build_auth_headers(
destination_bytes, method_bytes, url_to_sign_bytes,
headers_dict, json,
json,
)
data = encode_canonical_json(json)
producer = FileBodyProducer(
@ -310,40 +293,40 @@ class MatrixFederationHttpClient(object):
)
else:
producer = None
self.sign_request(
auth_headers = self.build_auth_headers(
destination_bytes, method_bytes, url_to_sign_bytes,
headers_dict,
)
headers_dict[b"Authorization"] = auth_headers
logger.info(
"{%s} [%s] Sending request: %s %s",
"{%s} [%s] Sending request: %s %s; timeout %fs",
request.txn_id, request.destination, request.method,
url_str,
)
# we don't want all the fancy cookie and redirect handling that
# treq.request gives: just use the raw Agent.
request_deferred = self.agent.request(
method_bytes,
url_bytes,
headers=Headers(headers_dict),
bodyProducer=producer,
)
request_deferred = timeout_deferred(
request_deferred,
timeout=_sec_timeout,
reactor=self.hs.get_reactor(),
url_str, _sec_timeout,
)
try:
with Measure(self.clock, "outbound_request"):
response = yield make_deferred_yieldable(
request_deferred,
# we don't want all the fancy cookie and redirect handling
# that treq.request gives: just use the raw Agent.
request_deferred = self.agent.request(
method_bytes,
url_bytes,
headers=Headers(headers_dict),
bodyProducer=producer,
)
request_deferred = timeout_deferred(
request_deferred,
timeout=_sec_timeout,
reactor=self.hs.get_reactor(),
)
response = yield request_deferred
except DNSLookupError as e:
raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e)
except Exception as e:
logger.info("Failed to send request: %s", e)
raise_from(RequestSendFailed(e, can_retry=True), e)
logger.info(
@ -441,24 +424,23 @@ class MatrixFederationHttpClient(object):
defer.returnValue(response)
def sign_request(self, destination, method, url_bytes, headers_dict,
content=None, destination_is=None):
def build_auth_headers(
self, destination, method, url_bytes, content=None, destination_is=None,
):
"""
Signs a request by adding an Authorization header to headers_dict
Builds the Authorization headers for a federation request
Args:
destination (bytes|None): The destination home server of the request.
May be None if the destination is an identity server, in which case
destination_is must be non-None.
method (bytes): The HTTP method of the request
url_bytes (bytes): The URI path of the request
headers_dict (dict[bytes, list[bytes]]): Dictionary of request headers to
append to
content (object): The body of the request
destination_is (bytes): As 'destination', but if the destination is an
identity server
Returns:
None
list[bytes]: a list of headers to be added as "Authorization:" headers
"""
request = {
"method": method,
@ -485,8 +467,7 @@ class MatrixFederationHttpClient(object):
self.server_name, key, sig,
)).encode('ascii')
)
headers_dict[b"Authorization"] = auth_headers
return auth_headers
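Callers now attach the returned headers themselves, as the identity handler hunk earlier in this diff does. A hedged sketch, with `client` standing in for a MatrixFederationHttpClient instance:

    auth_headers = client.build_auth_headers(
        destination=None,
        method=b"POST",
        url_bytes=b"/_matrix/identity/api/v1/3pid/unbind",
        content={"mxid": "@user:example.org"},
        destination_is=b"id.example.org",
    )
    headers = {b"Authorization": auth_headers}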
@defer.inlineCallbacks
def put_json(self, destination, path, args={}, data={},
@ -516,17 +497,18 @@ class MatrixFederationHttpClient(object):
requests)
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Fails with ``HttpResponseException`` if we get an HTTP response
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
@ -570,17 +552,18 @@ class MatrixFederationHttpClient(object):
try the request anyway.
args (dict): query params
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Fails with ``HttpResponseException`` if we get an HTTP response
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
@ -625,17 +608,18 @@ class MatrixFederationHttpClient(object):
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Fails with ``HttpResponseException`` if we get an HTTP response
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
logger.debug("get_json args: %s", args)
@ -676,17 +660,18 @@ class MatrixFederationHttpClient(object):
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Fails with ``HttpResponseException`` if we get an HTTP response
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
method="DELETE",
@ -719,18 +704,20 @@ class MatrixFederationHttpClient(object):
args (dict): Optional dictionary used to create the query string.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
Returns:
Deferred: resolves with an (int,dict) tuple of the file length and
a dict of the response headers.
Deferred[tuple[int, dict]]: Resolves with an (int,dict) tuple of
the file length and a dict of the response headers.
Fails with ``HttpResponseException`` if we get an HTTP response code
>= 300
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
method="GET",

View file

@ -40,7 +40,11 @@ REQUIREMENTS = [
"signedjson>=1.0.0",
"pynacl>=1.2.1",
"service_identity>=16.0.0",
"Twisted>=17.1.0",
# our logcontext handling relies on the ability to cancel inlineCallbacks
# (https://twistedmatrix.com/trac/ticket/4632) which landed in Twisted 18.7.
"Twisted>=18.7.0",
"treq>=15.1",
# Twisted has required pyopenssl 16.0 since about Twisted 16.6.
"pyopenssl>=16.0.0",
@ -52,15 +56,18 @@ REQUIREMENTS = [
"pillow>=3.1.2",
"sortedcontainers>=1.4.4",
"psutil>=2.0.0",
"pymacaroons-pynacl>=0.9.3",
"msgpack-python>=0.4.2",
"pymacaroons>=0.13.0",
"msgpack>=0.5.0",
"phonenumbers>=8.2.0",
"six>=1.10",
# prometheus_client 0.4.0 changed the format of counter metrics
# (cf https://github.com/matrix-org/synapse/issues/4001)
"prometheus_client>=0.0.18,<0.4.0",
# we use attr.s(slots), which arrived in 16.0.0
"attrs>=16.0.0",
# Twisted 18.7.0 requires attrs>=17.4.0
"attrs>=17.4.0",
"netaddr>=0.7.18",
]
@ -72,6 +79,10 @@ CONDITIONAL_REQUIREMENTS = {
# ConsentResource uses select_autoescape, which arrived in jinja 2.9
"resources.consent": ["Jinja2>=2.9"],
# ACME support is required to provision TLS certificates from authorities
# that use the protocol, such as Let's Encrypt.
"acme": ["txacme>=0.9.2"],
"saml2": ["pysaml2>=4.5.0"],
"url_preview": ["lxml>=3.5.0"],
"test": ["mock>=2.0"],

View file

@ -309,22 +309,16 @@ class RegisterRestServlet(RestServlet):
assigned_user_id=registered_user_id,
)
# Only give msisdn flows if the x_show_msisdn flag is given:
# this is a hack to work around the fact that clients were shipped
# that use fallback registration if they see any flows that they don't
# recognise, which means we break registration for these clients if we
# advertise msisdn flows. Once usage of Riot iOS <=0.3.9 and Riot
# Android <=0.6.9 have fallen below an acceptable threshold, this
# parameter should go away and we should always advertise msisdn flows.
show_msisdn = False
if 'x_show_msisdn' in body and body['x_show_msisdn']:
show_msisdn = True
# FIXME: need a better error than "no auth flow found" for scenarios
# where we required 3PID for registration but the user didn't give one
require_email = 'email' in self.hs.config.registrations_require_3pid
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
show_msisdn = True
if self.hs.config.disable_msisdn_registration:
show_msisdn = False
require_msisdn = False
flows = []
if self.hs.config.enable_registration_captcha:
# only support 3PIDless registration if no 3PIDs are required

View file

@ -46,6 +46,7 @@ from synapse.federation.transport.client import TransportLayerClient
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
from synapse.groups.groups_server import GroupsServerHandler
from synapse.handlers import Handlers
from synapse.handlers.acme import AcmeHandler
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
from synapse.handlers.deactivate_account import DeactivateAccountHandler
@ -129,6 +130,7 @@ class HomeServer(object):
'sync_handler',
'typing_handler',
'room_list_handler',
'acme_handler',
'auth_handler',
'device_handler',
'e2e_keys_handler',
@ -310,6 +312,9 @@ class HomeServer(object):
def build_e2e_room_keys_handler(self):
return E2eRoomKeysHandler(self)
def build_acme_handler(self):
return AcmeHandler(self)
def build_application_service_api(self):
return ApplicationServiceApi(self)

View file

@ -192,6 +192,41 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine
# A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = {"user_ips"}
if self.database_engine.can_native_upsert:
# Check ASAP (and then later, every 1s) to see if we have finished
# background updates of tables that aren't safe to update.
self._clock.call_later(0.0, self._check_safe_to_upsert)
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
"""
Is it safe to use native UPSERT?
If there are background updates, we will need to wait, as they may be
the addition of indexes that set the UNIQUE constraint that we require.
If the background updates have not completed, wait a second and check again.
"""
updates = yield self._simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
desc="check_background_updates",
)
updates = [x["update_name"] for x in updates]
# The user_ips table in schema #53 was missing a unique index, which we
# add via a background update. Once that update is no longer pending,
# upserting into user_ips is safe again.
if "user_ips_device_unique_index" not in updates:
self._unsafe_to_upsert_tables.discard("user_ips")
# If there's any tables left to check, reschedule to run.
if self._unsafe_to_upsert_tables:
self._clock.call_later(1.0, self._check_safe_to_upsert)
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@ -494,8 +529,15 @@ class SQLBaseStore(object):
txn.executemany(sql, vals)
@defer.inlineCallbacks
def _simple_upsert(self, table, keyvalues, values,
insertion_values={}, desc="_simple_upsert", lock=True):
def _simple_upsert(
self,
table,
keyvalues,
values,
insertion_values={},
desc="_simple_upsert",
lock=True
):
"""
`lock` should generally be set to True (the default), but can be set
@ -516,16 +558,21 @@ class SQLBaseStore(object):
inserting
lock (bool): True to lock the table when doing the upsert.
Returns:
Deferred(bool): True if a new entry was created, False if an
existing one was updated.
Deferred(None or bool): Native upserts always return None. Emulated
upserts return True if a new entry was created, False if an existing
one was updated.
"""
attempts = 0
while True:
try:
result = yield self.runInteraction(
desc,
self._simple_upsert_txn, table, keyvalues, values, insertion_values,
lock=lock
self._simple_upsert_txn,
table,
keyvalues,
values,
insertion_values,
lock=lock,
)
defer.returnValue(result)
except self.database_engine.module.IntegrityError as e:
@ -537,21 +584,76 @@ class SQLBaseStore(object):
# presumably we raced with another transaction: let's retry.
logger.warn(
"IntegrityError when upserting into %s; retrying: %s",
table, e
"%s when upserting into %s; retrying: %s", e.__name__, table, e
)
def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
lock=True):
def _simple_upsert_txn(
self,
txn,
table,
keyvalues,
values,
insertion_values={},
lock=True,
):
"""
Pick the UPSERT method which works best on the platform. Either the
native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
Args:
txn: The transaction to use.
table (str): The table to upsert into
keyvalues (dict): The unique key columns and their new values
values (dict): The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when
inserting
lock (bool): True to lock the table when doing the upsert.
Returns:
None or bool: Native upserts always return None. Emulated
upserts return True if a new entry was created, False if an existing
one was updated.
"""
if (
self.database_engine.can_native_upsert
and table not in self._unsafe_to_upsert_tables
):
return self._simple_upsert_txn_native_upsert(
txn,
table,
keyvalues,
values,
insertion_values=insertion_values,
)
else:
return self._simple_upsert_txn_emulated(
txn,
table,
keyvalues,
values,
insertion_values=insertion_values,
lock=lock,
)
def _simple_upsert_txn_emulated(
self, txn, table, keyvalues, values, insertion_values={}, lock=True
):
# We need to lock the table :(, unless we're *really* careful
if lock:
self.database_engine.lock_table(txn, table)
def _getwhere(key):
# If the value we're passing in is None (aka NULL), we need to use
# IS, not =, as NULL = NULL equals NULL (False).
if keyvalues[key] is None:
return "%s IS ?" % (key,)
else:
return "%s = ?" % (key,)
# First try to update.
sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in values),
" AND ".join("%s = ?" % (k,) for k in keyvalues)
" AND ".join(_getwhere(k) for k in keyvalues)
)
sqlargs = list(values.values()) + list(keyvalues.values())
@ -569,12 +671,44 @@ class SQLBaseStore(object):
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues)
", ".join("?" for _ in allvalues),
)
txn.execute(sql, list(allvalues.values()))
# successfully inserted
return True
def _simple_upsert_txn_native_upsert(
self, txn, table, keyvalues, values, insertion_values={}
):
"""
Use the native UPSERT functionality of the database (PostgreSQL 9.5+, SQLite 3.24+).
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key columns and their new values
values (dict): The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when
inserting
Returns:
None
"""
allvalues = {}
allvalues.update(keyvalues)
allvalues.update(values)
allvalues.update(insertion_values)
sql = (
"INSERT INTO %s (%s) VALUES (%s) "
"ON CONFLICT (%s) DO UPDATE SET %s"
) % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues),
", ".join(k for k in keyvalues),
", ".join(k + "=EXCLUDED." + k for k in values),
)
txn.execute(sql, list(allvalues.values()))
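A worked example of the statement this renders, with column sets borrowed from the user_ips upsert later in this diff (assumes dicts with stable insertion order, e.g. Python 3.7+):

    keyvalues = {"user_id": "@u:example.org", "access_token": "tok", "ip": "1.2.3.4"}
    values = {"user_agent": "UA", "device_id": "DEV", "last_seen": 0}
    allvalues = {}
    allvalues.update(keyvalues)
    allvalues.update(values)

    sql = (
        "INSERT INTO %s (%s) VALUES (%s) "
        "ON CONFLICT (%s) DO UPDATE SET %s"
    ) % (
        "user_ips",
        ", ".join(allvalues),
        ", ".join("?" for _ in allvalues),
        ", ".join(keyvalues),
        ", ".join(k + "=EXCLUDED." + k for k in values),
    )
    # -> INSERT INTO user_ips (user_id, access_token, ip, user_agent,
    #    device_id, last_seen) VALUES (?, ?, ?, ?, ?, ?)
    #    ON CONFLICT (user_id, access_token, ip)
    #    DO UPDATE SET user_agent=EXCLUDED.user_agent,
    #    device_id=EXCLUDED.device_id, last_seen=EXCLUDED.last_seen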
def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to

View file

@ -65,7 +65,27 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
columns=["last_seen"],
)
# (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
self.register_background_update_handler(
"user_ips_remove_dupes",
self._remove_user_ip_dupes,
)
# Register a unique index
self.register_background_index_update(
"user_ips_device_unique_index",
index_name="user_ips_user_token_ip_unique_index",
table="user_ips",
columns=["user_id", "access_token", "ip"],
unique=True,
)
# Drop the old non-unique index
self.register_background_update_handler(
"user_ips_drop_nonunique_index",
self._remove_user_ip_nonunique,
)
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
self._client_ip_looper = self._clock.looping_call(
@ -75,6 +95,129 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"before", "shutdown", self._update_client_ips_batch
)
@defer.inlineCallbacks
def _remove_user_ip_nonunique(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute(
"DROP INDEX IF EXISTS user_ips_user_ip"
)
txn.close()
yield self.runWithConnection(f)
yield self._end_background_update("user_ips_drop_nonunique_index")
defer.returnValue(1)
@defer.inlineCallbacks
def _remove_user_ip_dupes(self, progress, batch_size):
# This function works by scanning the user_ips table in batches
# based on `last_seen`. For each row in a batch it searches the rest of
# the table to see if there are any duplicates, if there are then they
# are removed and replaced with a suitable row.
# Fetch the start of the batch
begin_last_seen = progress.get("last_seen", 0)
def get_last_seen(txn):
txn.execute(
"""
SELECT last_seen FROM user_ips
WHERE last_seen > ?
ORDER BY last_seen
LIMIT 1
OFFSET ?
""",
(begin_last_seen, batch_size)
)
row = txn.fetchone()
if row:
return row[0]
else:
return None
# Get a last seen that has roughly `batch_size` since `begin_last_seen`
end_last_seen = yield self.runInteraction(
"user_ips_dups_get_last_seen", get_last_seen
)
# If it returns None, then we're processing the last batch
last = end_last_seen is None
logger.info(
"Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s",
begin_last_seen, end_last_seen,
)
def remove(txn):
# This works by looking at all entries in the given time span, and
# then for each (user_id, access_token, ip) tuple in that range
# checking for any duplicates in the rest of the table (via a join).
# It then only returns entries which have duplicates, and the max
# last_seen across all duplicates, which can then be used to delete
# all other duplicates.
# It is efficient due to the existence of (user_id, access_token,
# ip) and (last_seen) indices.
# Define the search space, which requires handling the last batch in
# a different way
if last:
clause = "? <= last_seen"
args = (begin_last_seen,)
else:
clause = "? <= last_seen AND last_seen < ?"
args = (begin_last_seen, end_last_seen)
txn.execute(
"""
SELECT user_id, access_token, ip,
MAX(device_id), MAX(user_agent), MAX(last_seen)
FROM (
SELECT user_id, access_token, ip
FROM user_ips
WHERE {}
) c
INNER JOIN user_ips USING (user_id, access_token, ip)
GROUP BY user_id, access_token, ip
HAVING count(*) > 1
""".format(clause),
args
)
res = txn.fetchall()
# We've got some duplicates
for i in res:
user_id, access_token, ip, device_id, user_agent, last_seen = i
# Drop all the duplicates
txn.execute(
"""
DELETE FROM user_ips
WHERE user_id = ? AND access_token = ? AND ip = ?
""",
(user_id, access_token, ip)
)
# Add in one to be the last_seen
txn.execute(
"""
INSERT INTO user_ips
(user_id, access_token, ip, device_id, user_agent, last_seen)
VALUES (?, ?, ?, ?, ?, ?)
""",
(user_id, access_token, ip, device_id, user_agent, last_seen)
)
self._background_update_progress_txn(
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
)
yield self.runInteraction("user_ips_dups_remove", remove)
if last:
yield self._end_background_update("user_ips_remove_dupes")
defer.returnValue(batch_size)
@defer.inlineCallbacks
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
now=None):
@ -114,7 +257,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
)
def _update_client_ips_batch_txn(self, txn, to_update):
self.database_engine.lock_table(txn, "user_ips")
if "user_ips" in self._unsafe_to_upsert_tables or (
not self.database_engine.can_native_upsert
):
self.database_engine.lock_table(txn, "user_ips")
for entry in iteritems(to_update):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
@ -127,10 +273,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"user_id": user_id,
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"device_id": device_id,
},
values={
"user_agent": user_agent,
"device_id": device_id,
"last_seen": last_seen,
},
lock=False,
@ -227,7 +373,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
results = {}
for key in self._batch_row_update:
uid, access_token, ip = key
uid, access_token, ip, = key
if uid == user_id:
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)

View file

@ -18,7 +18,7 @@ import platform
from ._base import IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite3 import Sqlite3Engine
from .sqlite import Sqlite3Engine
SUPPORTED_MODULE = {
"sqlite3": Sqlite3Engine,

View file

@ -38,6 +38,13 @@ class PostgresEngine(object):
return sql.replace("?", "%s")
def on_new_connection(self, db_conn):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
self._version = db_conn.server_version
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
@ -54,6 +61,13 @@ class PostgresEngine(object):
cursor.close()
@property
def can_native_upsert(self):
"""
Can we use native UPSERTs? This requires PostgreSQL 9.5+.
"""
return self._version >= 90500
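A quick sanity check of the version arithmetic described in on_new_connection above:

    def pack_pg_version(major, minor, revision):
        # psycopg2 packs each component as a two-decimal-digit group.
        return major * 10000 + minor * 100 + revision

    assert pack_pg_version(8, 1, 5) == 80105
    assert pack_pg_version(9, 4, 4) < 90500   # too old for native UPSERT
    assert pack_pg_version(9, 5, 0) >= 90500  # ON CONFLICT is available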
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html

View file

@ -15,6 +15,7 @@
import struct
import threading
from sqlite3 import sqlite_version_info
from synapse.storage.prepare_database import prepare_database
@ -30,6 +31,14 @@ class Sqlite3Engine(object):
self._current_state_group_id = None
self._current_state_group_id_lock = threading.Lock()
@property
def can_native_upsert(self):
"""
Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
more work we haven't done yet to tell what was inserted vs updated.
"""
return sqlite_version_info >= (3, 24, 0)
def check_database(self, txn):
pass

View file

@ -739,7 +739,18 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
}
events_map = {ev.event_id: ev for ev, _ in events_context}
room_version = yield self.get_room_version(room_id)
# We need to get the room version, which is in the create event.
# Normally that'd be in the database, but it's also possible that we're
# currently trying to persist it.
room_version = None
for ev, _ in events_context:
if ev.type == EventTypes.Create and ev.state_key == "":
room_version = ev.content.get("room_version", "1")
break
if not room_version:
room_version = yield self.get_room_version(room_id)
logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups(

View file

@ -215,7 +215,7 @@ class PusherStore(PusherWorkerStore):
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so _simple_upsert will retry
newly_inserted = yield self._simple_upsert(
yield self._simple_upsert(
table="pushers",
keyvalues={
"app_id": app_id,
@ -238,7 +238,12 @@ class PusherStore(PusherWorkerStore):
lock=False,
)
if newly_inserted:
user_has_pusher = self.get_if_user_has_pusher.cache.get(
(user_id,), None, update_metrics=False
)
if user_has_pusher is not True:
# invalidate, since the user might not have had a pusher before
yield self.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,

View file

@ -0,0 +1,26 @@
/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- delete duplicates
INSERT INTO background_updates (update_name, progress_json) VALUES
('user_ips_remove_dupes', '{}');
-- add a new unique index to user_ips table
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('user_ips_device_unique_index', '{}', 'user_ips_remove_dupes');
-- drop the old original index
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('user_ips_drop_nonunique_index', '{}', 'user_ips_device_unique_index');

View file

@ -168,14 +168,14 @@ class UserDirectoryStore(SQLBaseStore):
if isinstance(self.database_engine, PostgresEngine):
# We weight the localpart most highly, then display name and finally
# server name
if new_entry:
if self.database_engine.can_native_upsert:
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
setweight(to_tsvector('english', ?), 'A')
|| setweight(to_tsvector('english', ?), 'D')
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
)
) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
"""
txn.execute(
sql,
@ -185,20 +185,45 @@ class UserDirectoryStore(SQLBaseStore):
)
)
else:
sql = """
UPDATE user_directory_search
SET vector = setweight(to_tsvector('english', ?), 'A')
|| setweight(to_tsvector('english', ?), 'D')
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
WHERE user_id = ?
"""
txn.execute(
sql,
(
get_localpart_from_id(user_id), get_domain_from_id(user_id),
display_name, user_id,
# TODO: Remove this code after we've bumped the minimum version
# of postgres to always support upserts, so we can get rid of
# `new_entry` usage
if new_entry is True:
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
setweight(to_tsvector('english', ?), 'A')
|| setweight(to_tsvector('english', ?), 'D')
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
)
"""
txn.execute(
sql,
(
user_id, get_localpart_from_id(user_id),
get_domain_from_id(user_id), display_name,
)
)
elif new_entry is False:
sql = """
UPDATE user_directory_search
SET vector = setweight(to_tsvector('english', ?), 'A')
|| setweight(to_tsvector('english', ?), 'D')
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
WHERE user_id = ?
"""
txn.execute(
sql,
(
get_localpart_from_id(user_id),
get_domain_from_id(user_id),
display_name, user_id,
)
)
else:
raise RuntimeError(
"upsert returned None when 'can_native_upsert' is False"
)
)
elif isinstance(self.database_engine, Sqlite3Engine):
value = "%s %s" % (user_id, display_name,) if display_name else user_id
self._simple_upsert_txn(

View file

@ -387,12 +387,14 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
deferred that wraps and times out the given deferred, correctly handling
the case where the given deferred's canceller throws.
(See https://twistedmatrix.com/trac/ticket/9534)
NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred
Args:
deferred (Deferred)
timeout (float): Timeout in seconds
reactor (twisted.internet.reactor): The twisted reactor to use
reactor (twisted.interfaces.IReactorTime): The twisted reactor to use
on_timeout_cancel (callable): A callable which is called immediately
after the deferred times out, and not if this deferred is
otherwise cancelled before the timeout.
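A brief usage sketch, assuming only the signature documented above: the wrapped deferred must be used in place of the original, since a new one is returned.

    from twisted.internet import defer, reactor

    from synapse.util.async_helpers import timeout_deferred

    d = defer.Deferred()
    d = timeout_deferred(d, timeout=10.0, reactor=reactor)
    # d now errbacks if nothing fires it within 10 seconds.
    d.addErrback(lambda f: None)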

View file

@ -51,7 +51,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
"lemurs.win.log.config",
"lemurs.win.signing.key",
"lemurs.win.tls.crt",
"lemurs.win.tls.dh",
"lemurs.win.tls.key",
]
),

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,240 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from mock import Mock
import treq
from twisted.internet import defer
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.test.ssl_helpers import ServerTLSContext
from twisted.web.http import HTTPChannel
from synapse.crypto.context_factory import ClientTLSOptionsFactory
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.util.logcontext import LoggingContext
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.unittest import TestCase
logger = logging.getLogger(__name__)
class MatrixFederationAgentTests(TestCase):
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
self.mock_resolver = Mock()
self.agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=ClientTLSOptionsFactory(None),
_srv_resolver=self.mock_resolver,
)
def _make_connection(self, client_factory, expected_sni):
"""Builds a test server, and completes the outgoing client connection
Returns:
HTTPChannel: the test server
"""
# build the test server
server_tls_protocol = _build_test_server()
# now, tell the client protocol factory to build the client protocol (it will be a
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
# HTTP11ClientProtocol) and wire the output of said protocol up to the server via
# a FakeTransport.
#
# Normally this would be done by the TCP socket code in Twisted, but we are
# stubbing that out here.
client_protocol = client_factory.buildProtocol(None)
client_protocol.makeConnection(FakeTransport(server_tls_protocol, self.reactor))
# tell the server tls protocol to send its stuff back to the client, too
server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
# check the SNI
server_name = server_tls_protocol._tlsConnection.get_servername()
self.assertEqual(
server_name,
expected_sni,
"Expected SNI %s but got %s" % (expected_sni, server_name),
)
# fish the test server back out of the server-side TLS protocol.
return server_tls_protocol.wrappedProtocol
@defer.inlineCallbacks
def _make_get_request(self, uri):
"""
Sends a simple GET request via the agent, and checks its logcontext management
"""
with LoggingContext("one") as context:
fetch_d = self.agent.request(b'GET', uri)
# Nothing happened yet
self.assertNoResult(fetch_d)
# should have reset logcontext to the sentinel
_check_logcontext(LoggingContext.sentinel)
try:
fetch_res = yield fetch_d
defer.returnValue(fetch_res)
finally:
_check_logcontext(context)
def test_get(self):
"""
happy-path test of a GET request
"""
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8448)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
expected_sni=b"testserv",
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'testserv:8448']
)
content = request.content.read()
self.assertEqual(content, b'')
# Deferred is still without a result
self.assertNoResult(test_d)
# send the headers
request.responseHeaders.setRawHeaders(b'Content-Type', [b'application/json'])
request.write('')
self.reactor.pump((0.1,))
response = self.successResultOf(test_d)
# that should give us a Response object
self.assertEqual(response.code, 200)
# Send the body
request.write('{ "a": 1 }'.encode('ascii'))
request.finish()
self.reactor.pump((0.1,))
# check it can be read
json = self.successResultOf(treq.json_content(response))
self.assertEqual(json, {"a": 1})
def test_get_ip_address(self):
"""
Test the behaviour when the server name contains an explicit IP (with no port)
"""
# the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?)
self.mock_resolver.resolve_service.side_effect = lambda _: []
# then there will be a getaddrinfo on the IP
self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once()
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8448)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
expected_sni=None,
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
# XXX currently broken
# self.assertEqual(
# request.requestHeaders.getRawHeaders(b'host'),
# [b'1.2.3.4:8448']
# )
# finish the request
request.finish()
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def _check_logcontext(context):
current = LoggingContext.current_context()
if current is not context:
raise AssertionError(
"Expected logcontext %s but was %s" % (context, current),
)
def _build_test_server():
"""Construct a test server
This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
Returns:
TLSMemoryBIOProtocol
"""
server_factory = Factory.forProtocol(HTTPChannel)
# Request.finish expects the factory to have a 'log' method.
server_factory.log = _log_request
server_tls_factory = TLSMemoryBIOFactory(
ServerTLSContext(), isClient=False, wrappedFactory=server_factory,
)
return server_tls_factory.buildProtocol(None)
def _log_request(request):
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)

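The wiring in ``_make_connection`` is the reusable part of this file: any two in-memory protocols can be joined back-to-back with a pair of FakeTransports. Condensed, with names as in the test above::

    # Connect a client factory's protocol to the in-memory TLS server.
    client = client_factory.buildProtocol(None)
    server = _build_test_server()

    client.makeConnection(FakeTransport(server, reactor))  # client -> server
    server.makeConnection(FakeTransport(client, reactor))  # server -> client

    reactor.pump((0.1,))  # a pump lets the TLS handshake complete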
View file

@ -0,0 +1,207 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.internet.error import ConnectError
from twisted.names import dns, error
from synapse.http.federation.srv_resolver import SrvResolver
from synapse.util.logcontext import LoggingContext
from tests import unittest
from tests.utils import MockClock
class SrvResolverTestCase(unittest.TestCase):
def test_resolve(self):
dns_client_mock = Mock()
service_name = b"test_service.example.com"
host_name = b"example.com"
answer_srv = dns.RRHeader(
type=dns.SRV, payload=dns.Record_SRV(target=host_name)
)
result_deferred = Deferred()
dns_client_mock.lookupService.return_value = result_deferred
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
@defer.inlineCallbacks
def do_lookup():
with LoggingContext("one") as ctx:
resolve_d = resolver.resolve_service(service_name)
self.assertNoResult(resolve_d)
# should have reset to the sentinel context
self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
result = yield resolve_d
# should have restored our context
self.assertIs(LoggingContext.current_context(), ctx)
defer.returnValue(result)
test_d = do_lookup()
self.assertNoResult(test_d)
dns_client_mock.lookupService.assert_called_once_with(service_name)
result_deferred.callback(
([answer_srv], None, None)
)
servers = self.successResultOf(test_d)
self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name])
self.assertEquals(servers[0].host, host_name)
@defer.inlineCallbacks
def test_from_cache_expired_and_dns_fail(self):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
service_name = b"test_service.example.com"
entry = Mock(spec_set=["expires"])
entry.expires = 0
cache = {service_name: [entry]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
servers = yield resolver.resolve_service(service_name)
dns_client_mock.lookupService.assert_called_once_with(service_name)
self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name])
@defer.inlineCallbacks
def test_from_cache(self):
clock = MockClock()
dns_client_mock = Mock(spec_set=['lookupService'])
dns_client_mock.lookupService = Mock(spec_set=[])
service_name = b"test_service.example.com"
entry = Mock(spec_set=["expires"])
entry.expires = 999999999
cache = {service_name: [entry]}
resolver = SrvResolver(
dns_client=dns_client_mock, cache=cache, get_time=clock.time,
)
servers = yield resolver.resolve_service(service_name)
self.assertFalse(dns_client_mock.lookupService.called)
self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name])
@defer.inlineCallbacks
def test_empty_cache(self):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
service_name = b"test_service.example.com"
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
with self.assertRaises(error.DNSServerError):
yield resolver.resolve_service(service_name)
@defer.inlineCallbacks
def test_name_error(self):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
service_name = b"test_service.example.com"
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
servers = yield resolver.resolve_service(service_name)
self.assertEquals(len(servers), 0)
self.assertEquals(len(cache), 0)
def test_disabled_service(self):
"""
test the behaviour when there is a single record which is ".".
"""
service_name = b"test_service.example.com"
lookup_deferred = Deferred()
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
resolve_d = resolver.resolve_service(service_name)
self.assertNoResult(resolve_d)
# returning a single "." should make the lookup fail with a ConnectError
lookup_deferred.callback((
[dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))],
None,
None,
))
self.failureResultOf(resolve_d, ConnectError)
def test_non_srv_answer(self):
"""
test the behaviour when the dns server gives us a spurious non-SRV response
"""
service_name = b"test_service.example.com"
lookup_deferred = Deferred()
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
resolve_d = resolver.resolve_service(service_name)
self.assertNoResult(resolve_d)
lookup_deferred.callback((
[
dns.RRHeader(type=dns.A, payload=dns.Record_A()),
dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")),
],
None,
None,
))
servers = self.successResultOf(resolve_d)
self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name])
self.assertEquals(servers[0].host, b"host")

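The ``test_disabled_service`` case encodes RFC 2782 semantics: a single SRV record whose target is "." means the service is explicitly unavailable, so the resolver must fail rather than return an empty list. A sketch of the check (a hypothetical helper, not Synapse's actual code)::

    from twisted.internet.error import ConnectError
    from twisted.names import dns

    def check_service_disabled(answers, service_name):
        # RFC 2782: a lone "." target means "service decidedly absent".
        if len(answers) == 1 and answers[0].payload.target == dns.Name(b"."):
            raise ConnectError("Service %s unavailable" % (service_name,))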
View file

@ -15,8 +15,10 @@
from mock import Mock
from twisted.internet import defer
from twisted.internet.defer import TimeoutError
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
from twisted.test.proto_helpers import StringTransport
from twisted.web.client import ResponseNeverReceived
from twisted.web.http import HTTPChannel
@ -25,11 +27,20 @@ from synapse.http.matrixfederationclient import (
MatrixFederationHttpClient,
MatrixFederationRequest,
)
from synapse.util.logcontext import LoggingContext
from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase
def check_logcontext(context):
current = LoggingContext.current_context()
if current is not context:
raise AssertionError(
"Expected logcontext %s but was %s" % (context, current),
)
class FederationClientTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
@ -42,9 +53,73 @@ class FederationClientTests(HomeserverTestCase):
self.cl = MatrixFederationHttpClient(self.hs)
self.reactor.lookups["testserv"] = "1.2.3.4"
def test_client_get(self):
"""
happy-path test of a GET request
"""
@defer.inlineCallbacks
def do_request():
with LoggingContext("one") as context:
fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
# Nothing happened yet
self.assertNoResult(fetch_d)
# should have reset logcontext to the sentinel
check_logcontext(LoggingContext.sentinel)
try:
fetch_res = yield fetch_d
defer.returnValue(fetch_res)
finally:
check_logcontext(context)
test_d = do_request()
self.pump()
# Nothing happened yet
self.assertNoResult(test_d)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8008)
# complete the connection and wire it up to a fake transport
protocol = factory.buildProtocol(None)
transport = StringTransport()
protocol.makeConnection(transport)
# that should have made it send the request to the transport
self.assertRegex(transport.value(), b"^GET /foo/bar")
# Deferred is still without a result
self.assertNoResult(test_d)
# Send it the HTTP response
res_json = '{ "a": 1 }'.encode('ascii')
protocol.dataReceived(
b"HTTP/1.1 200 OK\r\n"
b"Server: Fake\r\n"
b"Content-Type: application/json\r\n"
b"Content-Length: %i\r\n"
b"\r\n"
b"%s" % (len(res_json), res_json)
)
self.pump()
res = self.successResultOf(test_d)
# check the response is as expected
self.assertEqual(res, {"a": 1})
def test_dns_error(self):
"""
If the DNS raising returns an error, it will bubble up.
If the DNS lookup returns an error, it will bubble up.
"""
d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
self.pump()
@ -53,6 +128,28 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
def test_client_connection_refused(self):
d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
self.pump()
# Nothing happened yet
self.assertNoResult(d)
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8008)
e = Exception("go away")
factory.clientConnectionFailed(None, e)
self.pump(0.5)
f = self.failureResultOf(d)
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIs(f.value.inner_exception, e)
def test_client_never_connect(self):
"""
If the HTTP request is not connected and is timed out, it'll give a
@ -63,7 +160,7 @@ class FederationClientTests(HomeserverTestCase):
self.pump()
# Nothing happened yet
self.assertFalse(d.called)
self.assertNoResult(d)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
@ -72,7 +169,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertEqual(clients[0][1], 8008)
# Deferred is still without a result
self.assertFalse(d.called)
self.assertNoResult(d)
# Push by enough to time it out
self.reactor.advance(10.5)
@ -94,7 +191,7 @@ class FederationClientTests(HomeserverTestCase):
self.pump()
# Nothing happened yet
self.assertFalse(d.called)
self.assertNoResult(d)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
@ -107,7 +204,7 @@ class FederationClientTests(HomeserverTestCase):
client.makeConnection(conn)
# Deferred is still without a result
self.assertFalse(d.called)
self.assertNoResult(d)
# Push by enough to time it out
self.reactor.advance(10.5)
@ -135,7 +232,7 @@ class FederationClientTests(HomeserverTestCase):
client.makeConnection(conn)
# Deferred does not have a result
self.assertFalse(d.called)
self.assertNoResult(d)
# Send it the HTTP response
client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")
@ -159,7 +256,7 @@ class FederationClientTests(HomeserverTestCase):
client.makeConnection(conn)
# Deferred does not have a result
self.assertFalse(d.called)
self.assertNoResult(d)
# Send it the HTTP response
client.dataReceived(
@ -195,3 +292,42 @@ class FederationClientTests(HomeserverTestCase):
request = server.requests[0]
content = request.content.read()
self.assertEqual(content, b'{"a":"b"}')
def test_closes_connection(self):
"""Check that the client closes unused HTTP connections"""
d = self.cl.get_json("testserv:8008", "foo/bar")
self.pump()
# there should have been a call to connectTCP
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(_host, _port, factory, _timeout, _bindAddress) = clients[0]
# complete the connection and wire it up to a fake transport
client = factory.buildProtocol(None)
conn = StringTransport()
client.makeConnection(conn)
# that should have made it send the request to the connection
self.assertRegex(conn.value(), b"^GET /foo/bar")
# Send the HTTP response
client.dataReceived(
b"HTTP/1.1 200 OK\r\n"
b"Content-Type: application/json\r\n"
b"Content-Length: 2\r\n"
b"\r\n"
b"{}"
)
# We should get a successful response
r = self.successResultOf(d)
self.assertEqual(r, {})
self.assertFalse(conn.disconnecting)
# wait for a while
self.pump(120)
self.assertTrue(conn.disconnecting)

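All of the tests in this file drive the client through the same in-memory loop: grab the factory from ``reactor.tcpClients``, build its protocol, attach a StringTransport, then feed raw HTTP bytes back in. The recurring skeleton, condensed from the tests above::

    (host, port, factory, _timeout, _bind) = self.reactor.tcpClients[0]

    client = factory.buildProtocol(None)
    conn = StringTransport()
    client.makeConnection(conn)

    # the serialized request is now in conn.value(); answer it in kind
    client.dataReceived(
        b"HTTP/1.1 200 OK\r\n"
        b"Content-Type: application/json\r\n"
        b"Content-Length: 2\r\n"
        b"\r\n"
        b"{}"
    )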
View file

@ -1,4 +1,5 @@
import json
import logging
from io import BytesIO
from six import text_type
@ -22,6 +23,8 @@ from synapse.util import Clock
from tests.utils import setup_test_homeserver as _sth
logger = logging.getLogger(__name__)
class TimedOutException(Exception):
"""
@ -339,7 +342,7 @@ def get_clock():
return (clock, hs_clock)
@attr.s
@attr.s(cmp=False)
class FakeTransport(object):
"""
A twisted.internet.interfaces.ITransport implementation which sends all its data
@ -414,6 +417,11 @@ class FakeTransport(object):
self.buffer = self.buffer + byt
def _write():
if not self.buffer:
# nothing to do. Don't write empty buffers: it upsets the
# TLSMemoryBIOProtocol
return
if getattr(self.other, "transport") is not None:
self.other.dataReceived(self.buffer)
self.buffer = b""
@ -421,7 +429,10 @@ class FakeTransport(object):
self._reactor.callLater(0.0, _write)
_write()
# always actually do the write asynchronously. Some protocols (notably the
# TLSMemoryBIOProtocol) get very confused if a read comes back while they are
# still doing a write. Doing a callLater here breaks the cycle.
self._reactor.callLater(0.0, _write)
def writeSequence(self, seq):
for x in seq:

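Put together, the write path the two hunks above describe looks like this (a sketch of ``FakeTransport.write``, with the surrounding attributes assumed)::

    def write(self, byt):
        self.buffer = self.buffer + byt

        def _write():
            if not self.buffer:
                # empty writes upset TLSMemoryBIOProtocol
                return
            if getattr(self.other, "transport", None) is not None:
                self.other.dataReceived(self.buffer)
                self.buffer = b""
            else:
                # peer not connected yet: retry on the next reactor tick
                self._reactor.callLater(0.0, _write)

        # never deliver synchronously -- a read arriving while
        # TLSMemoryBIOProtocol is mid-write confuses it
        self._reactor.callLater(0.0, _write)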
View file

@ -49,6 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.db_pool.runWithConnection = runWithConnection
config = Mock()
config._enable_native_upserts = False
config.event_cache_size = 1
config.database_config = {"name": "sqlite3"}
hs = TestHomeServer(

View file

@ -62,6 +62,77 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
r,
)
def test_insert_new_client_ip_none_device_id(self):
"""
A second insert with a device ID of NULL updates the existing entry
in the user_ips table instead of creating a duplicate row.
"""
self.reactor.advance(12345678)
user_id = "@user:id"
# Add & trigger the storage loop
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", None
)
)
self.reactor.advance(200)
self.pump(0)
result = self.get_success(
self.store._simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
desc="get_user_ip_and_agents",
)
)
self.assertEqual(
result,
[
{
'access_token': 'access_token',
'ip': 'ip',
'user_agent': 'user_agent',
'device_id': None,
'last_seen': 12345678000,
}
],
)
# Add another & trigger the storage loop
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", None
)
)
self.reactor.advance(10)
self.pump(0)
result = self.get_success(
self.store._simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
desc="get_user_ip_and_agents",
)
)
# Still only one result: the existing row was upserted, not duplicated.
self.assertEqual(
result,
[
{
'access_token': 'access_token',
'ip': 'ip',
'user_agent': 'user_agent',
'device_id': None,
'last_seen': 12345878000,
}
],
)
def test_disabled_monthly_active_user(self):
self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50

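The behaviour under test maps onto a single store-level upsert keyed on (user_id, access_token, ip), so a NULL device_id row is updated in place. Roughly (a sketch using the generic helper; column names taken from the retcols above)::

    self.store._simple_upsert(
        table="user_ips",
        keyvalues={
            "user_id": user_id,
            "access_token": "access_token",
            "ip": "ip",
        },
        values={
            "user_agent": "user_agent",
            "device_id": None,
            "last_seen": 12345878000,
        },
        desc="insert_client_ip",
    )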
View file

@ -1,129 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from twisted.internet import defer
from twisted.names import dns, error
from synapse.http.endpoint import resolve_service
from tests.utils import MockClock
from . import unittest
@unittest.DEBUG
class DnsTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_resolve(self):
dns_client_mock = Mock()
service_name = "test_service.example.com"
host_name = "example.com"
answer_srv = dns.RRHeader(
type=dns.SRV, payload=dns.Record_SRV(target=host_name)
)
dns_client_mock.lookupService.return_value = defer.succeed(
([answer_srv], None, None)
)
cache = {}
servers = yield resolve_service(
service_name, dns_client=dns_client_mock, cache=cache
)
dns_client_mock.lookupService.assert_called_once_with(service_name)
self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name])
self.assertEquals(servers[0].host, host_name)
@defer.inlineCallbacks
def test_from_cache_expired_and_dns_fail(self):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
service_name = "test_service.example.com"
entry = Mock(spec_set=["expires"])
entry.expires = 0
cache = {service_name: [entry]}
servers = yield resolve_service(
service_name, dns_client=dns_client_mock, cache=cache
)
dns_client_mock.lookupService.assert_called_once_with(service_name)
self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name])
@defer.inlineCallbacks
def test_from_cache(self):
clock = MockClock()
dns_client_mock = Mock(spec_set=['lookupService'])
dns_client_mock.lookupService = Mock(spec_set=[])
service_name = "test_service.example.com"
entry = Mock(spec_set=["expires"])
entry.expires = 999999999
cache = {service_name: [entry]}
servers = yield resolve_service(
service_name, dns_client=dns_client_mock, cache=cache, clock=clock
)
self.assertFalse(dns_client_mock.lookupService.called)
self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name])
@defer.inlineCallbacks
def test_empty_cache(self):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
service_name = "test_service.example.com"
cache = {}
with self.assertRaises(error.DNSServerError):
yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache)
@defer.inlineCallbacks
def test_name_error(self):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
service_name = "test_service.example.com"
cache = {}
servers = yield resolve_service(
service_name, dns_client=dns_client_mock, cache=cache
)
self.assertEquals(len(servers), 0)
self.assertEquals(len(cache), 0)

View file

@ -19,7 +19,7 @@ from six import StringIO
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@ -30,12 +30,18 @@ from synapse.util import Clock
from synapse.util.logcontext import make_deferred_yieldable
from tests import unittest
from tests.server import FakeTransport, make_request, render, setup_test_homeserver
from tests.server import (
FakeTransport,
ThreadedMemoryReactorClock,
make_request,
render,
setup_test_homeserver,
)
class JsonResourceTests(unittest.TestCase):
def setUp(self):
self.reactor = MemoryReactorClock()
self.reactor = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.reactor)
self.homeserver = setup_test_homeserver(
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor

View file

@ -96,7 +96,7 @@ class TestCase(unittest.TestCase):
method = getattr(self, methodName)
level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR))
level = getattr(method, "loglevel", getattr(self, "loglevel", logging.WARNING))
@around(self)
def setUp(orig):
@ -333,7 +333,15 @@ class HomeserverTestCase(TestCase):
"""
kwargs = dict(kwargs)
kwargs.update(self._hs_args)
return setup_test_homeserver(self.addCleanup, *args, **kwargs)
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
# Run the database background updates.
if hasattr(stor, "do_next_background_update"):
while not self.get_success(stor.has_completed_background_updates()):
self.get_success(stor.do_next_background_update(1))
return hs
def pump(self, by=0.0):
"""

View file

@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred
from twisted.internet.task import Clock
from synapse.util import logcontext
from synapse.util.async_helpers import timeout_deferred
from synapse.util.logcontext import LoggingContext
from tests.unittest import TestCase
class TimeoutDeferredTest(TestCase):
def setUp(self):
self.clock = Clock()
def test_times_out(self):
"""Basic test case that checks that the original deferred is cancelled and that
the timing-out deferred is errbacked
"""
cancelled = [False]
def canceller(_d):
cancelled[0] = True
non_completing_d = Deferred(canceller)
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
self.assertNoResult(timing_out_d)
self.assertFalse(cancelled[0], "deferred was cancelled prematurely")
self.clock.pump((1.0, ))
self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
self.failureResultOf(timing_out_d, defer.TimeoutError, )
def test_times_out_when_canceller_throws(self):
"""Test that we have successfully worked around
https://twistedmatrix.com/trac/ticket/9534"""
def canceller(_d):
raise Exception("can't cancel this deferred")
non_completing_d = Deferred(canceller)
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
self.assertNoResult(timing_out_d)
self.clock.pump((1.0, ))
self.failureResultOf(timing_out_d, defer.TimeoutError, )
def test_logcontext_is_preserved_on_cancellation(self):
blocking_was_cancelled = [False]
@defer.inlineCallbacks
def blocking():
non_completing_d = Deferred()
with logcontext.PreserveLoggingContext():
try:
yield non_completing_d
except CancelledError:
blocking_was_cancelled[0] = True
raise
with logcontext.LoggingContext("one") as context_one:
# the errbacks should be run in the test logcontext
def errback(res, deferred_name):
self.assertIs(
LoggingContext.current_context(), context_one,
"errback %s run in unexpected logcontext %s" % (
deferred_name, LoggingContext.current_context(),
)
)
return res
original_deferred = blocking()
original_deferred.addErrback(errback, "orig")
timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock)
self.assertNoResult(timing_out_d)
self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
timing_out_d.addErrback(errback, "timingout")
self.clock.pump((1.0, ))
self.assertTrue(
blocking_was_cancelled[0],
"non-completing deferred was not cancelled",
)
self.failureResultOf(timing_out_d, defer.TimeoutError, )
self.assertIs(LoggingContext.current_context(), context_one)

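The ticket 9534 failure mode is a canceller that raises: with ``Deferred.addTimeout`` the exception escapes and the timeout errback never fires. The workaround these tests exercise is, in outline (a sketch of the idea, not Synapse's exact implementation)::

    from twisted.internet import defer

    def timeout_deferred_sketch(deferred, timeout, reactor):
        new_d = defer.Deferred()
        timed_out = [False]

        def time_it_out():
            timed_out[0] = True
            try:
                deferred.cancel()
            except Exception:
                pass  # a throwing canceller must not swallow the timeout
            if not new_d.called:
                new_d.errback(defer.TimeoutError("timed out"))

        delayed_call = reactor.callLater(timeout, time_it_out)

        def success(result):
            if delayed_call.active():
                delayed_call.cancel()
            if not timed_out[0]:
                new_d.callback(result)

        def failed(failure):
            if delayed_call.active():
                delayed_call.cancel()
            if not timed_out[0]:
                new_d.errback(failure)
            # after a timeout, the CancelledError from cancel() is dropped

        deferred.addCallbacks(success, failed)
        return new_d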
View file

@ -149,4 +149,5 @@ deps =
codecov
commands =
coverage combine
coverage xml
codecov -X gcov