Merge branch 'anoa/room_dir_quick_fix' into matrix-org-hotfixes
This commit is contained in:
commit
acaca1b4e9
15
.codecov.yml
Normal file
15
.codecov.yml
Normal file
|
@ -0,0 +1,15 @@
|
|||
comment:
|
||||
layout: "diff"
|
||||
|
||||
coverage:
|
||||
status:
|
||||
project:
|
||||
default:
|
||||
target: 0 # Target % coverage, can be auto. Turned off for now
|
||||
threshold: null
|
||||
base: auto
|
||||
patch:
|
||||
default:
|
||||
target: 0
|
||||
threshold: null
|
||||
base: auto
|
|
@ -1,11 +1,7 @@
|
|||
[run]
|
||||
branch = True
|
||||
parallel = True
|
||||
source = synapse
|
||||
|
||||
[paths]
|
||||
source=
|
||||
coverage
|
||||
include = synapse/*
|
||||
|
||||
[report]
|
||||
precision = 2
|
||||
|
|
6
.gitignore
vendored
6
.gitignore
vendored
|
@ -25,9 +25,9 @@ homeserver*.pid
|
|||
*.tls.dh
|
||||
*.tls.key
|
||||
|
||||
.coverage
|
||||
.coverage.*
|
||||
!.coverage.rc
|
||||
.coverage*
|
||||
coverage.*
|
||||
!.coveragerc
|
||||
htmlcov
|
||||
|
||||
demo/*/*.db
|
||||
|
|
10
.travis.yml
10
.travis.yml
|
@ -12,6 +12,9 @@ cache:
|
|||
#
|
||||
- $HOME/.cache/pip/wheels
|
||||
|
||||
addons:
|
||||
postgresql: "9.4"
|
||||
|
||||
# don't clone the whole repo history, one commit will do
|
||||
git:
|
||||
depth: 1
|
||||
|
@ -68,6 +71,13 @@ matrix:
|
|||
|
||||
install:
|
||||
- pip install tox
|
||||
|
||||
# if we don't have python3.6 in this environment, travis unhelpfully gives us
|
||||
# a `python3.6` on our path which does nothing but spit out a warning. Tox
|
||||
# tries to run it (even if we're not running a py36 env), so the build logs
|
||||
# then have warnings which look like errors. To reduce the noise, remove the
|
||||
# non-functional python3.6.
|
||||
- ( ! command -v python3.6 || python3.6 --version ) &>/dev/null || rm -f $(command -v python3.6)
|
||||
|
||||
script:
|
||||
- tox -e $TOX_ENV
|
||||
|
|
|
@ -37,6 +37,7 @@ prune docker
|
|||
prune .circleci
|
||||
prune .coveragerc
|
||||
prune debian
|
||||
prune .codecov.yml
|
||||
|
||||
exclude jenkins*
|
||||
recursive-exclude jenkins *.sh
|
||||
|
|
|
@ -184,7 +184,7 @@ Configuring Synapse
|
|||
Before you can start Synapse, you will need to generate a configuration
|
||||
file. To do this, run (in your virtualenv, as before)::
|
||||
|
||||
cd ~/.synapse
|
||||
cd ~/synapse
|
||||
python -m synapse.app.homeserver \
|
||||
--server-name my.domain.name \
|
||||
--config-path homeserver.yaml \
|
||||
|
@ -220,7 +220,7 @@ is configured to use TLS with a self-signed certificate. If you would like
|
|||
to do initial test with a client without having to setup a reverse proxy,
|
||||
you can temporarly use another certificate. (Note that a self-signed
|
||||
certificate is fine for `Federation`_). You can do so by changing
|
||||
``tls_certificate_path``, ``tls_private_key_path`` and ``tls_dh_params_path``
|
||||
``tls_certificate_path`` and ``tls_private_key_path``
|
||||
in ``homeserver.yaml``; alternatively, you can use a reverse-proxy, but be sure
|
||||
to read `Using a reverse proxy with Synapse`_ when doing so.
|
||||
|
||||
|
@ -796,8 +796,7 @@ A manual password reset can be done via direct database access as follows.
|
|||
|
||||
First calculate the hash of the new password::
|
||||
|
||||
$ source ~/.synapse/bin/activate
|
||||
$ ./scripts/hash_password
|
||||
$ ~/synapse/env/bin/hash_password
|
||||
Password:
|
||||
Confirm password:
|
||||
$2a$12$xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
|
|
1
changelog.d/4229.feature
Normal file
1
changelog.d/4229.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Synapse's cipher string has been updated to require ECDH key exchange. Configuring and generating dh_params is no longer required, and they will be ignored.
|
1
changelog.d/4306.misc
Normal file
1
changelog.d/4306.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+.
|
1
changelog.d/4342.misc
Normal file
1
changelog.d/4342.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Update README to use the new virtualenv everywhere
|
1
changelog.d/4368.misc
Normal file
1
changelog.d/4368.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add better logging for unexpected errors while sending transactions
|
1
changelog.d/4369.bugfix
Normal file
1
changelog.d/4369.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Prevent users with access tokens predating the introduction of device IDs from creating spurious entries in the user_ips table.
|
1
changelog.d/4370.misc
Normal file
1
changelog.d/4370.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Apply a unique index to the user_ips table, preventing duplicates.
|
1
changelog.d/4377.misc
Normal file
1
changelog.d/4377.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Silence travis-ci build warnings by removing non-functional python3.6
|
1
changelog.d/4384.feature
Normal file
1
changelog.d/4384.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Synapse can now automatically provision TLS certificates via ACME (the protocol used by CAs like Let's Encrypt).
|
1
changelog.d/4387.misc
Normal file
1
changelog.d/4387.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Fix a comment in the generated config file
|
1
changelog.d/4390.misc
Normal file
1
changelog.d/4390.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add ground work for implementing future federation API versions
|
1
changelog.d/4392.bugfix
Normal file
1
changelog.d/4392.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix typo in ALL_USER_TYPES definition to ensure type is a tuple
|
1
changelog.d/4397.bugfix
Normal file
1
changelog.d/4397.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix high CPU usage due to remote devicelist updates
|
1
changelog.d/4399.misc
Normal file
1
changelog.d/4399.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Update dependencies on msgpack and pymacaroons to use the up-to-date packages.
|
1
changelog.d/4400.misc
Normal file
1
changelog.d/4400.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Tweak codecov settings to make them less loud.
|
1
changelog.d/4402.misc
Normal file
1
changelog.d/4402.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Implement server support for MSC1794 - Federation v2 Invite API
|
1
changelog.d/4404.bugfix
Normal file
1
changelog.d/4404.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix potential bug where creating or joining a room could fail
|
1
changelog.d/4407.bugfix
Normal file
1
changelog.d/4407.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix incorrect logcontexts after a Deferred was cancelled
|
1
changelog.d/4408.misc
Normal file
1
changelog.d/4408.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Refactor 'sign_request' as 'build_auth_headers'
|
1
changelog.d/4409.misc
Normal file
1
changelog.d/4409.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Remove redundant federation connection wrapping code
|
1
changelog.d/4411.bugfix
Normal file
1
changelog.d/4411.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Ensure encrypted room state is persisted across room upgrades.
|
1
changelog.d/4423.feature
Normal file
1
changelog.d/4423.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Config option to disable requesting MSISDN on registration.
|
1
changelog.d/4426.misc
Normal file
1
changelog.d/4426.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Remove redundant SynapseKeyClientProtocol magic
|
1
changelog.d/4427.misc
Normal file
1
changelog.d/4427.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Refactor and cleanup for SRV record lookup
|
1
changelog.d/4428.misc
Normal file
1
changelog.d/4428.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Move SRV logic into the Agent layer
|
1
changelog.d/4432.misc
Normal file
1
changelog.d/4432.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Apply a unique index to the user_ips table, preventing duplicates.
|
1
changelog.d/4433.misc
Normal file
1
changelog.d/4433.misc
Normal file
|
@ -0,0 +1 @@
|
|||
debian package: symlink to explicit python version
|
1
changelog.d/4434.misc
Normal file
1
changelog.d/4434.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Apply a unique index to the user_ips table, preventing duplicates.
|
1
changelog.d/4445.feature
Normal file
1
changelog.d/4445.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add a metric for tracking event stream position of the user directory.
|
1
changelog.d/4452.bugfix
Normal file
1
changelog.d/4452.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Don't send IP addresses as SNI
|
1
changelog.d/4461.bugfix
Normal file
1
changelog.d/4461.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Add a timeout to filtered room directory queries.
|
15
debian/build_virtualenv
vendored
15
debian/build_virtualenv
vendored
|
@ -6,7 +6,16 @@
|
|||
set -e
|
||||
|
||||
export DH_VIRTUALENV_INSTALL_ROOT=/opt/venvs
|
||||
SNAKE=/usr/bin/python3
|
||||
|
||||
# make sure that the virtualenv links to the specific version of python, by
|
||||
# dereferencing the python3 symlink.
|
||||
#
|
||||
# Otherwise, if somebody tries to install (say) the stretch package on buster,
|
||||
# they will get a confusing error about "No module named 'synapse'", because
|
||||
# python won't look in the right directory. At least this way, the error will
|
||||
# be a *bit* more obvious.
|
||||
#
|
||||
SNAKE=`readlink -e /usr/bin/python3`
|
||||
|
||||
# try to set the CFLAGS so any compiled C extensions are compiled with the most
|
||||
# generic as possible x64 instructions, so that compiling it on a new Intel chip
|
||||
|
@ -46,3 +55,7 @@ cp -r tests "$tmpdir"
|
|||
PYTHONPATH="$tmpdir" \
|
||||
debian/matrix-synapse-py3/opt/venvs/matrix-synapse/bin/python \
|
||||
-B -m twisted.trial --reporter=text -j2 tests
|
||||
|
||||
# add a dependency on the right version of python to substvars.
|
||||
PYPKG=`basename $SNAKE`
|
||||
echo "synapse:pydepends=$PYPKG" >> debian/matrix-synapse-py3.substvars
|
||||
|
|
12
debian/changelog
vendored
12
debian/changelog
vendored
|
@ -1,3 +1,15 @@
|
|||
matrix-synapse-py3 (0.34.1.1++1) stable; urgency=medium
|
||||
|
||||
* Update conflicts specifications to allow smoother transition from matrix-synapse.
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Sat, 12 Jan 2019 12:58:35 +0000
|
||||
|
||||
matrix-synapse-py3 (0.34.1.1) stable; urgency=high
|
||||
|
||||
* New synapse release 0.34.1.1
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Thu, 10 Jan 2019 15:04:52 +0000
|
||||
|
||||
matrix-synapse-py3 (0.34.1+1) stable; urgency=medium
|
||||
|
||||
* Remove 'Breaks: matrix-synapse-ldap3'. (matrix-synapse-py3 includes
|
||||
|
|
8
debian/control
vendored
8
debian/control
vendored
|
@ -19,16 +19,16 @@ Homepage: https://github.com/matrix-org/synapse
|
|||
Package: matrix-synapse-py3
|
||||
Architecture: amd64
|
||||
Provides: matrix-synapse
|
||||
Breaks:
|
||||
matrix-synapse (<< 0.34.0-0matrix2),
|
||||
matrix-synapse (>= 0.34.0-1),
|
||||
Conflicts:
|
||||
matrix-synapse (<< 0.34.0.1-0matrix2),
|
||||
matrix-synapse (>= 0.34.0.1-1),
|
||||
Pre-Depends: dpkg (>= 1.16.1)
|
||||
Depends:
|
||||
adduser,
|
||||
debconf,
|
||||
python3-distutils|libpython3-stdlib (<< 3.6),
|
||||
python3,
|
||||
${misc:Depends},
|
||||
${synapse:pydepends},
|
||||
# some of our scripts use perl, but none of them are important,
|
||||
# so we put perl:Depends in Suggests rather than Depends.
|
||||
Suggests:
|
||||
|
|
3
debian/homeserver.yaml
vendored
3
debian/homeserver.yaml
vendored
|
@ -9,9 +9,6 @@ tls_certificate_path: "/etc/matrix-synapse/homeserver.tls.crt"
|
|||
# PEM encoded private key for TLS
|
||||
tls_private_key_path: "/etc/matrix-synapse/homeserver.tls.key"
|
||||
|
||||
# PEM dh parameters for ephemeral keys
|
||||
tls_dh_params_path: "/etc/matrix-synapse/homeserver.tls.dh"
|
||||
|
||||
# Don't bind to the https port
|
||||
no_tls: False
|
||||
|
||||
|
|
|
@ -1,9 +0,0 @@
|
|||
2048-bit DH parameters taken from rfc3526
|
||||
-----BEGIN DH PARAMETERS-----
|
||||
MIIBCAKCAQEA///////////JD9qiIWjCNMTGYouA3BzRKQJOCIpnzHQCC76mOxOb
|
||||
IlFKCHmONATd75UZs806QxswKwpt8l8UN0/hNW1tUcJF5IW1dmJefsb0TELppjft
|
||||
awv/XLb0Brft7jhr+1qJn6WunyQRfEsf5kkoZlHs5Fs9wgB8uKFjvwWY2kg2HFXT
|
||||
mmkWP6j9JM9fg2VdI9yjrZYcYvNWIIVSu57VKQdwlpZtZww1Tkq8mATxdGwIyhgh
|
||||
fDKQXkYuNs474553LBgOhgObJ4Oi7Aeij7XFXfBvTFLJ3ivL9pVYFxg5lUl86pVq
|
||||
5RXSJhiY+gUQFXKOWoqsqmj//////////wIBAg==
|
||||
-----END DH PARAMETERS-----
|
|
@ -1,46 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Build the Debian packages using Docker images.
|
||||
#
|
||||
# This script builds the Docker images and then executes them sequentially, each
|
||||
# one building a Debian package for the targeted operating system. It is
|
||||
# designed to be a "single command" to produce all the images.
|
||||
#
|
||||
# By default, builds for all known distributions, but a list of distributions
|
||||
# can be passed on the commandline for debugging.
|
||||
|
||||
set -ex
|
||||
|
||||
cd `dirname $0`
|
||||
|
||||
if [ $# -lt 1 ]; then
|
||||
DISTS=(
|
||||
debian:stretch
|
||||
debian:buster
|
||||
debian:sid
|
||||
ubuntu:xenial
|
||||
ubuntu:bionic
|
||||
ubuntu:cosmic
|
||||
)
|
||||
else
|
||||
DISTS=("$@")
|
||||
fi
|
||||
|
||||
# Make the dir where the debs will live.
|
||||
#
|
||||
# Note that we deliberately put this outside the source tree, otherwise we tend
|
||||
# to get source packages which are full of debs. (We could hack around that
|
||||
# with more magic in the build_debian.sh script, but that doesn't solve the
|
||||
# problem for natively-run dpkg-buildpakage).
|
||||
|
||||
mkdir -p ../../debs
|
||||
|
||||
# Build each OS image;
|
||||
for i in "${DISTS[@]}"; do
|
||||
TAG=$(echo ${i} | cut -d ":" -f 2)
|
||||
docker build --tag dh-venv-builder:${TAG} --build-arg distro=${i} -f Dockerfile-dhvirtualenv .
|
||||
docker run -it --rm --volume=$(pwd)/../\:/synapse/source:ro --volume=$(pwd)/../../debs:/debs \
|
||||
-e TARGET_USERID=$(id -u) \
|
||||
-e TARGET_GROUPID=$(id -g) \
|
||||
dh-venv-builder:${TAG}
|
||||
done
|
|
@ -4,7 +4,6 @@
|
|||
|
||||
tls_certificate_path: "/data/{{ SYNAPSE_SERVER_NAME }}.tls.crt"
|
||||
tls_private_key_path: "/data/{{ SYNAPSE_SERVER_NAME }}.tls.key"
|
||||
tls_dh_params_path: "/data/{{ SYNAPSE_SERVER_NAME }}.tls.dh"
|
||||
no_tls: {{ "True" if SYNAPSE_NO_TLS else "False" }}
|
||||
tls_fingerprints: []
|
||||
|
||||
|
|
154
scripts-dev/build_debian_packages
Executable file
154
scripts-dev/build_debian_packages
Executable file
|
@ -0,0 +1,154 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Build the Debian packages using Docker images.
|
||||
#
|
||||
# This script builds the Docker images and then executes them sequentially, each
|
||||
# one building a Debian package for the targeted operating system. It is
|
||||
# designed to be a "single command" to produce all the images.
|
||||
#
|
||||
# By default, builds for all known distributions, but a list of distributions
|
||||
# can be passed on the commandline for debugging.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
DISTS = (
|
||||
"debian:stretch",
|
||||
"debian:buster",
|
||||
"debian:sid",
|
||||
"ubuntu:xenial",
|
||||
"ubuntu:bionic",
|
||||
"ubuntu:cosmic",
|
||||
)
|
||||
|
||||
DESC = '''\
|
||||
Builds .debs for synapse, using a Docker image for the build environment.
|
||||
|
||||
By default, builds for all known distributions, but a list of distributions
|
||||
can be passed on the commandline for debugging.
|
||||
'''
|
||||
|
||||
|
||||
class Builder(object):
|
||||
def __init__(self, redirect_stdout=False):
|
||||
self.redirect_stdout = redirect_stdout
|
||||
self.active_containers = set()
|
||||
self._lock = threading.Lock()
|
||||
self._failed = False
|
||||
|
||||
def run_build(self, dist):
|
||||
"""Build deb for a single distribution"""
|
||||
|
||||
if self._failed:
|
||||
print("not building %s due to earlier failure" % (dist, ))
|
||||
raise Exception("failed")
|
||||
|
||||
try:
|
||||
self._inner_build(dist)
|
||||
except Exception as e:
|
||||
print("build of %s failed: %s" % (dist, e), file=sys.stderr)
|
||||
self._failed = True
|
||||
raise
|
||||
|
||||
def _inner_build(self, dist):
|
||||
projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
os.chdir(projdir)
|
||||
|
||||
tag = dist.split(":", 1)[1]
|
||||
|
||||
# Make the dir where the debs will live.
|
||||
#
|
||||
# Note that we deliberately put this outside the source tree, otherwise
|
||||
# we tend to get source packages which are full of debs. (We could hack
|
||||
# around that with more magic in the build_debian.sh script, but that
|
||||
# doesn't solve the problem for natively-run dpkg-buildpakage).
|
||||
debsdir = os.path.join(projdir, '../debs')
|
||||
os.makedirs(debsdir, exist_ok=True)
|
||||
|
||||
if self.redirect_stdout:
|
||||
logfile = os.path.join(debsdir, "%s.buildlog" % (tag, ))
|
||||
print("building %s: directing output to %s" % (dist, logfile))
|
||||
stdout = open(logfile, "w")
|
||||
else:
|
||||
stdout = None
|
||||
|
||||
# first build a docker image for the build environment
|
||||
subprocess.check_call([
|
||||
"docker", "build",
|
||||
"--tag", "dh-venv-builder:" + tag,
|
||||
"--build-arg", "distro=" + dist,
|
||||
"-f", "docker/Dockerfile-dhvirtualenv",
|
||||
"docker",
|
||||
], stdout=stdout, stderr=subprocess.STDOUT)
|
||||
|
||||
container_name = "synapse_build_" + tag
|
||||
with self._lock:
|
||||
self.active_containers.add(container_name)
|
||||
|
||||
# then run the build itself
|
||||
subprocess.check_call([
|
||||
"docker", "run",
|
||||
"--rm",
|
||||
"--name", container_name,
|
||||
"--volume=" + projdir + ":/synapse/source:ro",
|
||||
"--volume=" + debsdir + ":/debs",
|
||||
"-e", "TARGET_USERID=%i" % (os.getuid(), ),
|
||||
"-e", "TARGET_GROUPID=%i" % (os.getgid(), ),
|
||||
"dh-venv-builder:" + tag,
|
||||
], stdout=stdout, stderr=subprocess.STDOUT)
|
||||
|
||||
with self._lock:
|
||||
self.active_containers.remove(container_name)
|
||||
|
||||
if stdout is not None:
|
||||
stdout.close()
|
||||
print("Completed build of %s" % (dist, ))
|
||||
|
||||
def kill_containers(self):
|
||||
with self._lock:
|
||||
active = list(self.active_containers)
|
||||
|
||||
for c in active:
|
||||
print("killing container %s" % (c,))
|
||||
subprocess.run([
|
||||
"docker", "kill", c,
|
||||
], stdout=subprocess.DEVNULL)
|
||||
with self._lock:
|
||||
self.active_containers.remove(c)
|
||||
|
||||
|
||||
def run_builds(dists, jobs=1):
|
||||
builder = Builder(redirect_stdout=(jobs > 1))
|
||||
|
||||
def sig(signum, _frame):
|
||||
print("Caught SIGINT")
|
||||
builder.kill_containers()
|
||||
signal.signal(signal.SIGINT, sig)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=jobs) as e:
|
||||
res = e.map(builder.run_build, dists)
|
||||
|
||||
# make sure we consume the iterable so that exceptions are raised.
|
||||
for r in res:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description=DESC,
|
||||
)
|
||||
parser.add_argument(
|
||||
'-j', '--jobs', type=int, default=1,
|
||||
help='specify the number of builds to run in parallel',
|
||||
)
|
||||
parser.add_argument(
|
||||
'dist', nargs='*', default=DISTS,
|
||||
help='a list of distributions to build for. Default: %(default)s',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
run_builds(dists=args.dist, jobs=args.jobs)
|
|
@ -68,6 +68,7 @@ class EventTypes(object):
|
|||
Aliases = "m.room.aliases"
|
||||
Redaction = "m.room.redaction"
|
||||
ThirdPartyInvite = "m.room.third_party_invite"
|
||||
Encryption = "m.room.encryption"
|
||||
|
||||
RoomHistoryVisibility = "m.room.history_visibility"
|
||||
CanonicalAlias = "m.room.canonical_alias"
|
||||
|
@ -128,4 +129,4 @@ class UserTypes(object):
|
|||
'admin' and 'guest' users should also be UserTypes. Normal users are type None
|
||||
"""
|
||||
SUPPORT = "support"
|
||||
ALL_USER_TYPES = (SUPPORT)
|
||||
ALL_USER_TYPES = (SUPPORT,)
|
||||
|
|
|
@ -24,7 +24,9 @@ from synapse.config import ConfigError
|
|||
|
||||
CLIENT_PREFIX = "/_matrix/client/api/v1"
|
||||
CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"
|
||||
FEDERATION_PREFIX = "/_matrix/federation/v1"
|
||||
FEDERATION_PREFIX = "/_matrix/federation"
|
||||
FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1"
|
||||
FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2"
|
||||
STATIC_PREFIX = "/_matrix/static"
|
||||
WEB_CLIENT_PREFIX = "/_matrix/client"
|
||||
CONTENT_REPO_PREFIX = "/_matrix/content"
|
||||
|
|
|
@ -13,10 +13,12 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from six import iteritems
|
||||
|
||||
|
@ -324,17 +326,12 @@ def setup(config_options):
|
|||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
|
||||
|
||||
hs = SynapseHomeServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
tls_client_options_factory=tls_client_options_factory,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
|
@ -361,12 +358,53 @@ def setup(config_options):
|
|||
logger.info("Database prepared in %s.", config.database_config['name'])
|
||||
|
||||
hs.setup()
|
||||
hs.start_listening()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start():
|
||||
hs.get_pusherpool().start()
|
||||
hs.get_datastore().start_profiling()
|
||||
hs.get_datastore().start_doing_background_updates()
|
||||
try:
|
||||
# Check if the certificate is still valid.
|
||||
cert_days_remaining = hs.config.is_disk_cert_valid()
|
||||
|
||||
if hs.config.acme_enabled:
|
||||
# If ACME is enabled, we might need to provision a certificate
|
||||
# before starting.
|
||||
acme = hs.get_acme_handler()
|
||||
|
||||
# Start up the webservices which we will respond to ACME
|
||||
# challenges with.
|
||||
yield acme.start_listening()
|
||||
|
||||
# We want to reprovision if cert_days_remaining is None (meaning no
|
||||
# certificate exists), or the days remaining number it returns
|
||||
# is less than our re-registration threshold.
|
||||
if (cert_days_remaining is None) or (
|
||||
not cert_days_remaining > hs.config.acme_reprovision_threshold
|
||||
):
|
||||
yield acme.provision_certificate()
|
||||
|
||||
# Read the certificate from disk and build the context factories for
|
||||
# TLS.
|
||||
hs.config.read_certificate_from_disk()
|
||||
hs.tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
hs.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
|
||||
config
|
||||
)
|
||||
|
||||
# It is now safe to start your Synapse.
|
||||
hs.start_listening()
|
||||
hs.get_pusherpool().start()
|
||||
hs.get_datastore().start_profiling()
|
||||
hs.get_datastore().start_doing_background_updates()
|
||||
except Exception as e:
|
||||
# If a DeferredList failed (like in listening on the ACME listener),
|
||||
# we need to print the subfailure explicitly.
|
||||
if isinstance(e, defer.FirstError):
|
||||
e.subFailure.printTraceback(sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Something else went wrong when starting. Print it and bail out.
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
|
|
|
@ -367,7 +367,7 @@ class Config(object):
|
|||
if not keys_directory:
|
||||
keys_directory = os.path.dirname(config_files[-1])
|
||||
|
||||
config_dir_path = os.path.abspath(keys_directory)
|
||||
self.config_dir_path = os.path.abspath(keys_directory)
|
||||
|
||||
specified_config = {}
|
||||
for config_file in config_files:
|
||||
|
@ -379,7 +379,7 @@ class Config(object):
|
|||
|
||||
server_name = specified_config["server_name"]
|
||||
config_string = self.generate_config(
|
||||
config_dir_path=config_dir_path,
|
||||
config_dir_path=self.config_dir_path,
|
||||
data_dir_path=os.getcwd(),
|
||||
server_name=server_name,
|
||||
generate_secrets=False,
|
||||
|
|
|
@ -83,9 +83,6 @@ class KeyConfig(Config):
|
|||
# a secret which is used to sign access tokens. If none is specified,
|
||||
# the registration_shared_secret is used, if one is given; otherwise,
|
||||
# a secret key is derived from the signing key.
|
||||
#
|
||||
# Note that changing this will invalidate any active access tokens, so
|
||||
# all clients will have to log back in.
|
||||
%(macaroon_secret_key)s
|
||||
|
||||
# Used to enable access token expiration.
|
||||
|
|
|
@ -50,6 +50,10 @@ class RegistrationConfig(Config):
|
|||
raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,))
|
||||
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
|
||||
|
||||
self.disable_msisdn_registration = (
|
||||
config.get("disable_msisdn_registration", False)
|
||||
)
|
||||
|
||||
def default_config(self, generate_secrets=False, **kwargs):
|
||||
if generate_secrets:
|
||||
registration_shared_secret = 'registration_shared_secret: "%s"' % (
|
||||
|
@ -70,6 +74,11 @@ class RegistrationConfig(Config):
|
|||
# - email
|
||||
# - msisdn
|
||||
|
||||
# Explicitly disable asking for MSISDNs from the registration
|
||||
# flow (overrides registrations_require_3pid if MSISDNs are set as required)
|
||||
#
|
||||
# disable_msisdn_registration = True
|
||||
|
||||
# Mandate that users are only allowed to associate certain formats of
|
||||
# 3PIDs with accounts on this server.
|
||||
#
|
||||
|
|
|
@ -13,52 +13,38 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
||||
from OpenSSL import crypto
|
||||
|
||||
from ._base import Config
|
||||
from synapse.config._base import Config
|
||||
|
||||
GENERATE_DH_PARAMS = False
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class TlsConfig(Config):
|
||||
def read_config(self, config):
|
||||
self.tls_certificate = self.read_tls_certificate(
|
||||
config.get("tls_certificate_path")
|
||||
)
|
||||
self.tls_certificate_file = config.get("tls_certificate_path")
|
||||
|
||||
acme_config = config.get("acme", {})
|
||||
self.acme_enabled = acme_config.get("enabled", False)
|
||||
self.acme_url = acme_config.get(
|
||||
"url", "https://acme-v01.api.letsencrypt.org/directory"
|
||||
)
|
||||
self.acme_port = acme_config.get("port", 8449)
|
||||
self.acme_bind_addresses = acme_config.get("bind_addresses", ["127.0.0.1"])
|
||||
self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
|
||||
|
||||
self.tls_certificate_file = os.path.abspath(config.get("tls_certificate_path"))
|
||||
self.tls_private_key_file = os.path.abspath(config.get("tls_private_key_path"))
|
||||
self._original_tls_fingerprints = config["tls_fingerprints"]
|
||||
self.tls_fingerprints = list(self._original_tls_fingerprints)
|
||||
self.no_tls = config.get("no_tls", False)
|
||||
|
||||
if self.no_tls:
|
||||
self.tls_private_key = None
|
||||
else:
|
||||
self.tls_private_key = self.read_tls_private_key(
|
||||
config.get("tls_private_key_path")
|
||||
)
|
||||
|
||||
self.tls_dh_params_path = self.check_file(
|
||||
config.get("tls_dh_params_path"), "tls_dh_params"
|
||||
)
|
||||
|
||||
self.tls_fingerprints = config["tls_fingerprints"]
|
||||
|
||||
# Check that our own certificate is included in the list of fingerprints
|
||||
# and include it if it is not.
|
||||
x509_certificate_bytes = crypto.dump_certificate(
|
||||
crypto.FILETYPE_ASN1,
|
||||
self.tls_certificate
|
||||
)
|
||||
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
|
||||
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
|
||||
if sha256_fingerprint not in sha256_fingerprints:
|
||||
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
|
||||
|
||||
# This config option applies to non-federation HTTP clients
|
||||
# (e.g. for talking to recaptcha, identity servers, and such)
|
||||
# It should never be used in production, and is intended for
|
||||
|
@ -67,14 +53,70 @@ class TlsConfig(Config):
|
|||
"use_insecure_ssl_client_just_for_testing_do_not_use"
|
||||
)
|
||||
|
||||
self.tls_certificate = None
|
||||
self.tls_private_key = None
|
||||
|
||||
def is_disk_cert_valid(self):
|
||||
"""
|
||||
Is the certificate we have on disk valid, and if so, for how long?
|
||||
|
||||
Returns:
|
||||
int: Days remaining of certificate validity.
|
||||
None: No certificate exists.
|
||||
"""
|
||||
if not os.path.exists(self.tls_certificate_file):
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(self.tls_certificate_file, 'rb') as f:
|
||||
cert_pem = f.read()
|
||||
except Exception:
|
||||
logger.exception("Failed to read existing certificate off disk!")
|
||||
raise
|
||||
|
||||
try:
|
||||
tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
|
||||
except Exception:
|
||||
logger.exception("Failed to parse existing certificate off disk!")
|
||||
raise
|
||||
|
||||
# YYYYMMDDhhmmssZ -- in UTC
|
||||
expires_on = datetime.strptime(
|
||||
tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ"
|
||||
)
|
||||
now = datetime.utcnow()
|
||||
days_remaining = (expires_on - now).days
|
||||
return days_remaining
|
||||
|
||||
def read_certificate_from_disk(self):
|
||||
"""
|
||||
Read the certificates from disk.
|
||||
"""
|
||||
self.tls_certificate = self.read_tls_certificate(self.tls_certificate_file)
|
||||
|
||||
if not self.no_tls:
|
||||
self.tls_private_key = self.read_tls_private_key(self.tls_private_key_file)
|
||||
|
||||
self.tls_fingerprints = list(self._original_tls_fingerprints)
|
||||
|
||||
# Check that our own certificate is included in the list of fingerprints
|
||||
# and include it if it is not.
|
||||
x509_certificate_bytes = crypto.dump_certificate(
|
||||
crypto.FILETYPE_ASN1, self.tls_certificate
|
||||
)
|
||||
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
|
||||
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
|
||||
if sha256_fingerprint not in sha256_fingerprints:
|
||||
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
|
||||
|
||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||
base_key_name = os.path.join(config_dir_path, server_name)
|
||||
|
||||
tls_certificate_path = base_key_name + ".tls.crt"
|
||||
tls_private_key_path = base_key_name + ".tls.key"
|
||||
tls_dh_params_path = base_key_name + ".tls.dh"
|
||||
|
||||
return """\
|
||||
return (
|
||||
"""\
|
||||
# PEM encoded X509 certificate for TLS.
|
||||
# You can replace the self-signed certificate that synapse
|
||||
# autogenerates on launch with your own SSL certificate + key pair
|
||||
|
@ -85,9 +127,6 @@ class TlsConfig(Config):
|
|||
# PEM encoded private key for TLS
|
||||
tls_private_key_path: "%(tls_private_key_path)s"
|
||||
|
||||
# PEM dh parameters for ephemeral keys
|
||||
tls_dh_params_path: "%(tls_dh_params_path)s"
|
||||
|
||||
# Don't bind to the https port
|
||||
no_tls: False
|
||||
|
||||
|
@ -118,7 +157,24 @@ class TlsConfig(Config):
|
|||
#
|
||||
tls_fingerprints: []
|
||||
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
|
||||
""" % locals()
|
||||
|
||||
## Support for ACME certificate auto-provisioning.
|
||||
# acme:
|
||||
# enabled: false
|
||||
## ACME path.
|
||||
## If you only want to test, use the staging url:
|
||||
## https://acme-staging.api.letsencrypt.org/directory
|
||||
# url: 'https://acme-v01.api.letsencrypt.org/directory'
|
||||
## Port number (to listen for the HTTP-01 challenge).
|
||||
## Using port 80 requires utilising something like authbind, or proxying to it.
|
||||
# port: 8449
|
||||
## Hosts to bind to.
|
||||
# bind_addresses: ['127.0.0.1']
|
||||
## How many days remaining on a certificate before it is renewed.
|
||||
# reprovision_threshold: 30
|
||||
"""
|
||||
% locals()
|
||||
)
|
||||
|
||||
def read_tls_certificate(self, cert_path):
|
||||
cert_pem = self.read_file(cert_path, "tls_certificate")
|
||||
|
@ -131,7 +187,6 @@ class TlsConfig(Config):
|
|||
def generate_files(self, config):
|
||||
tls_certificate_path = config["tls_certificate_path"]
|
||||
tls_private_key_path = config["tls_private_key_path"]
|
||||
tls_dh_params_path = config["tls_dh_params_path"]
|
||||
|
||||
if not self.path_exists(tls_private_key_path):
|
||||
with open(tls_private_key_path, "wb") as private_key_file:
|
||||
|
@ -165,31 +220,3 @@ class TlsConfig(Config):
|
|||
cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
|
||||
|
||||
certificate_file.write(cert_pem)
|
||||
|
||||
if not self.path_exists(tls_dh_params_path):
|
||||
if GENERATE_DH_PARAMS:
|
||||
subprocess.check_call([
|
||||
"openssl", "dhparam",
|
||||
"-outform", "PEM",
|
||||
"-out", tls_dh_params_path,
|
||||
"2048"
|
||||
])
|
||||
else:
|
||||
with open(tls_dh_params_path, "w") as dh_params_file:
|
||||
dh_params_file.write(
|
||||
"2048-bit DH parameters taken from rfc3526\n"
|
||||
"-----BEGIN DH PARAMETERS-----\n"
|
||||
"MIIBCAKCAQEA///////////JD9qiIWjC"
|
||||
"NMTGYouA3BzRKQJOCIpnzHQCC76mOxOb\n"
|
||||
"IlFKCHmONATd75UZs806QxswKwpt8l8U"
|
||||
"N0/hNW1tUcJF5IW1dmJefsb0TELppjft\n"
|
||||
"awv/XLb0Brft7jhr+1qJn6WunyQRfEsf"
|
||||
"5kkoZlHs5Fs9wgB8uKFjvwWY2kg2HFXT\n"
|
||||
"mmkWP6j9JM9fg2VdI9yjrZYcYvNWIIVS"
|
||||
"u57VKQdwlpZtZww1Tkq8mATxdGwIyhgh\n"
|
||||
"fDKQXkYuNs474553LBgOhgObJ4Oi7Aei"
|
||||
"j7XFXfBvTFLJ3ivL9pVYFxg5lUl86pVq\n"
|
||||
"5RXSJhiY+gUQFXKOWoqsqmj/////////"
|
||||
"/wIBAg==\n"
|
||||
"-----END DH PARAMETERS-----\n"
|
||||
)
|
||||
|
|
|
@ -17,6 +17,7 @@ from zope.interface import implementer
|
|||
|
||||
from OpenSSL import SSL, crypto
|
||||
from twisted.internet._sslverify import _defaultCurveName
|
||||
from twisted.internet.abstract import isIPAddress, isIPv6Address
|
||||
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
|
||||
from twisted.internet.ssl import CertificateOptions, ContextFactory
|
||||
from twisted.python.failure import Failure
|
||||
|
@ -46,8 +47,10 @@ class ServerContextFactory(ContextFactory):
|
|||
if not config.no_tls:
|
||||
context.use_privatekey(config.tls_private_key)
|
||||
|
||||
context.load_tmp_dh(config.tls_dh_params_path)
|
||||
context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH")
|
||||
# https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
|
||||
context.set_cipher_list(
|
||||
"ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1"
|
||||
)
|
||||
|
||||
def getContext(self):
|
||||
return self._context
|
||||
|
@ -96,8 +99,14 @@ class ClientTLSOptions(object):
|
|||
|
||||
def __init__(self, hostname, ctx):
|
||||
self._ctx = ctx
|
||||
self._hostname = hostname
|
||||
self._hostnameBytes = _idnaBytes(hostname)
|
||||
|
||||
if isIPAddress(hostname) or isIPv6Address(hostname):
|
||||
self._hostnameBytes = hostname.encode('ascii')
|
||||
self._sendSNI = False
|
||||
else:
|
||||
self._hostnameBytes = _idnaBytes(hostname)
|
||||
self._sendSNI = True
|
||||
|
||||
ctx.set_info_callback(
|
||||
_tolerateErrors(self._identityVerifyingInfoCallback)
|
||||
)
|
||||
|
@ -109,7 +118,9 @@ class ClientTLSOptions(object):
|
|||
return connection
|
||||
|
||||
def _identityVerifyingInfoCallback(self, connection, where, ret):
|
||||
if where & SSL.SSL_CB_HANDSHAKE_START:
|
||||
# Literal IPv4 and IPv6 addresses are not permitted
|
||||
# as host names according to the RFCs
|
||||
if where & SSL.SSL_CB_HANDSHAKE_START and self._sendSNI:
|
||||
connection.set_tlsext_host_name(self._hostnameBytes)
|
||||
|
||||
|
||||
|
|
|
@ -1,149 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from six.moves import urllib
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.error import ConnectError
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.names.error import DomainError
|
||||
from twisted.web.http import HTTPClient
|
||||
|
||||
from synapse.http.endpoint import matrix_federation_endpoint
|
||||
from synapse.util import logcontext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KEY_API_V2 = "/_matrix/key/v2/server/%s"
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fetch_server_key(server_name, tls_client_options_factory, key_id):
|
||||
"""Fetch the keys for a remote server."""
|
||||
|
||||
factory = SynapseKeyClientFactory()
|
||||
factory.path = KEY_API_V2 % (urllib.parse.quote(key_id), )
|
||||
factory.host = server_name
|
||||
endpoint = matrix_federation_endpoint(
|
||||
reactor, server_name, tls_client_options_factory, timeout=30
|
||||
)
|
||||
|
||||
for i in range(5):
|
||||
try:
|
||||
with logcontext.PreserveLoggingContext():
|
||||
protocol = yield endpoint.connect(factory)
|
||||
server_response, server_certificate = yield protocol.remote_key
|
||||
defer.returnValue((server_response, server_certificate))
|
||||
except SynapseKeyClientError as e:
|
||||
logger.warn("Error getting key for %r: %s", server_name, e)
|
||||
if e.status.startswith(b"4"):
|
||||
# Don't retry for 4xx responses.
|
||||
raise IOError("Cannot get key for %r" % server_name)
|
||||
except (ConnectError, DomainError) as e:
|
||||
logger.warn("Error getting key for %r: %s", server_name, e)
|
||||
except Exception:
|
||||
logger.exception("Error getting key for %r", server_name)
|
||||
raise IOError("Cannot get key for %r" % server_name)
|
||||
|
||||
|
||||
class SynapseKeyClientError(Exception):
|
||||
"""The key wasn't retrieved from the remote server."""
|
||||
status = None
|
||||
pass
|
||||
|
||||
|
||||
class SynapseKeyClientProtocol(HTTPClient):
|
||||
"""Low level HTTPS client which retrieves an application/json response from
|
||||
the server and extracts the X.509 certificate for the remote peer from the
|
||||
SSL connection."""
|
||||
|
||||
timeout = 30
|
||||
|
||||
def __init__(self):
|
||||
self.remote_key = defer.Deferred()
|
||||
self.host = None
|
||||
self._peer = None
|
||||
|
||||
def connectionMade(self):
|
||||
self._peer = self.transport.getPeer()
|
||||
logger.debug("Connected to %s", self._peer)
|
||||
|
||||
if not isinstance(self.path, bytes):
|
||||
self.path = self.path.encode('ascii')
|
||||
|
||||
if not isinstance(self.host, bytes):
|
||||
self.host = self.host.encode('ascii')
|
||||
|
||||
self.sendCommand(b"GET", self.path)
|
||||
if self.host:
|
||||
self.sendHeader(b"Host", self.host)
|
||||
self.endHeaders()
|
||||
self.timer = reactor.callLater(
|
||||
self.timeout,
|
||||
self.on_timeout
|
||||
)
|
||||
|
||||
def errback(self, error):
|
||||
if not self.remote_key.called:
|
||||
self.remote_key.errback(error)
|
||||
|
||||
def callback(self, result):
|
||||
if not self.remote_key.called:
|
||||
self.remote_key.callback(result)
|
||||
|
||||
def handleStatus(self, version, status, message):
|
||||
if status != b"200":
|
||||
# logger.info("Non-200 response from %s: %s %s",
|
||||
# self.transport.getHost(), status, message)
|
||||
error = SynapseKeyClientError(
|
||||
"Non-200 response %r from %r" % (status, self.host)
|
||||
)
|
||||
error.status = status
|
||||
self.errback(error)
|
||||
self.transport.abortConnection()
|
||||
|
||||
def handleResponse(self, response_body_bytes):
|
||||
try:
|
||||
json_response = json.loads(response_body_bytes)
|
||||
except ValueError:
|
||||
# logger.info("Invalid JSON response from %s",
|
||||
# self.transport.getHost())
|
||||
self.transport.abortConnection()
|
||||
return
|
||||
|
||||
certificate = self.transport.getPeerCertificate()
|
||||
self.callback((json_response, certificate))
|
||||
self.transport.abortConnection()
|
||||
self.timer.cancel()
|
||||
|
||||
def on_timeout(self):
|
||||
logger.debug(
|
||||
"Timeout waiting for response from %s: %s",
|
||||
self.host, self._peer,
|
||||
)
|
||||
self.errback(IOError("Timeout waiting for response"))
|
||||
self.transport.abortConnection()
|
||||
|
||||
|
||||
class SynapseKeyClientFactory(Factory):
|
||||
def protocol(self):
|
||||
protocol = SynapseKeyClientProtocol()
|
||||
protocol.path = self.path
|
||||
protocol.host = self.host
|
||||
return protocol
|
|
@ -14,10 +14,11 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
from six.moves import urllib
|
||||
|
||||
from signedjson.key import (
|
||||
decode_verify_key_bytes,
|
||||
encode_verify_key_base64,
|
||||
|
@ -30,13 +31,11 @@ from signedjson.sign import (
|
|||
signature_ids,
|
||||
verify_signed_json,
|
||||
)
|
||||
from unpaddedbase64 import decode_base64, encode_base64
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
from OpenSSL import crypto
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.crypto.keyclient import fetch_server_key
|
||||
from synapse.util import logcontext, unwrapFirstError
|
||||
from synapse.util.logcontext import (
|
||||
LoggingContext,
|
||||
|
@ -503,31 +502,16 @@ class Keyring(object):
|
|||
if requested_key_id in keys:
|
||||
continue
|
||||
|
||||
(response, tls_certificate) = yield fetch_server_key(
|
||||
server_name, self.hs.tls_client_options_factory, requested_key_id
|
||||
response = yield self.client.get_json(
|
||||
destination=server_name,
|
||||
path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id),
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
if (u"signatures" not in response
|
||||
or server_name not in response[u"signatures"]):
|
||||
raise KeyLookupError("Key response not signed by remote server")
|
||||
|
||||
if "tls_fingerprints" not in response:
|
||||
raise KeyLookupError("Key response missing TLS fingerprints")
|
||||
|
||||
certificate_bytes = crypto.dump_certificate(
|
||||
crypto.FILETYPE_ASN1, tls_certificate
|
||||
)
|
||||
sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
|
||||
sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
|
||||
|
||||
response_sha256_fingerprints = set()
|
||||
for fingerprint in response[u"tls_fingerprints"]:
|
||||
if u"sha256" in fingerprint:
|
||||
response_sha256_fingerprints.add(fingerprint[u"sha256"])
|
||||
|
||||
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
|
||||
raise KeyLookupError("TLS certificate not allowed by fingerprints")
|
||||
|
||||
response_keys = yield self.process_v2_response(
|
||||
from_server=server_name,
|
||||
requested_ids=[requested_key_id],
|
||||
|
|
|
@ -369,13 +369,13 @@ class FederationServer(FederationBase):
|
|||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_invite_request(self, origin, content):
|
||||
def on_invite_request(self, origin, content, room_version):
|
||||
pdu = event_from_pdu_json(content)
|
||||
origin_host, _ = parse_server_name(origin)
|
||||
yield self.check_server_matches_acl(origin_host, pdu.room_id)
|
||||
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
|
||||
time_now = self._clock.time_msec()
|
||||
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
|
||||
defer.returnValue({"event": ret_pdu.get_pdu_json(time_now)})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_send_join_request(self, origin, content):
|
||||
|
|
|
@ -21,7 +21,7 @@ from six.moves import urllib
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
||||
from synapse.api.urls import FEDERATION_V1_PREFIX
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -51,7 +51,7 @@ class TransportLayerClient(object):
|
|||
logger.debug("get_room_state dest=%s, room=%s",
|
||||
destination, room_id)
|
||||
|
||||
path = _create_path(PREFIX, "/state/%s/", room_id)
|
||||
path = _create_v1_path("/state/%s/", room_id)
|
||||
return self.client.get_json(
|
||||
destination, path=path, args={"event_id": event_id},
|
||||
)
|
||||
|
@ -73,7 +73,7 @@ class TransportLayerClient(object):
|
|||
logger.debug("get_room_state_ids dest=%s, room=%s",
|
||||
destination, room_id)
|
||||
|
||||
path = _create_path(PREFIX, "/state_ids/%s/", room_id)
|
||||
path = _create_v1_path("/state_ids/%s/", room_id)
|
||||
return self.client.get_json(
|
||||
destination, path=path, args={"event_id": event_id},
|
||||
)
|
||||
|
@ -95,7 +95,7 @@ class TransportLayerClient(object):
|
|||
logger.debug("get_pdu dest=%s, event_id=%s",
|
||||
destination, event_id)
|
||||
|
||||
path = _create_path(PREFIX, "/event/%s/", event_id)
|
||||
path = _create_v1_path("/event/%s/", event_id)
|
||||
return self.client.get_json(destination, path=path, timeout=timeout)
|
||||
|
||||
@log_function
|
||||
|
@ -121,7 +121,7 @@ class TransportLayerClient(object):
|
|||
# TODO: raise?
|
||||
return
|
||||
|
||||
path = _create_path(PREFIX, "/backfill/%s/", room_id)
|
||||
path = _create_v1_path("/backfill/%s/", room_id)
|
||||
|
||||
args = {
|
||||
"v": event_tuples,
|
||||
|
@ -167,7 +167,7 @@ class TransportLayerClient(object):
|
|||
# generated by the json_data_callback.
|
||||
json_data = transaction.get_dict()
|
||||
|
||||
path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id)
|
||||
path = _create_v1_path("/send/%s/", transaction.transaction_id)
|
||||
|
||||
response = yield self.client.put_json(
|
||||
transaction.destination,
|
||||
|
@ -184,7 +184,7 @@ class TransportLayerClient(object):
|
|||
@log_function
|
||||
def make_query(self, destination, query_type, args, retry_on_dns_fail,
|
||||
ignore_backoff=False):
|
||||
path = _create_path(PREFIX, "/query/%s", query_type)
|
||||
path = _create_v1_path("/query/%s", query_type)
|
||||
|
||||
content = yield self.client.get_json(
|
||||
destination=destination,
|
||||
|
@ -231,7 +231,7 @@ class TransportLayerClient(object):
|
|||
"make_membership_event called with membership='%s', must be one of %s" %
|
||||
(membership, ",".join(valid_memberships))
|
||||
)
|
||||
path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id)
|
||||
path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id)
|
||||
|
||||
ignore_backoff = False
|
||||
retry_on_dns_fail = False
|
||||
|
@ -258,7 +258,7 @@ class TransportLayerClient(object):
|
|||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_join(self, destination, room_id, event_id, content):
|
||||
path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id)
|
||||
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
|
||||
|
||||
response = yield self.client.put_json(
|
||||
destination=destination,
|
||||
|
@ -271,7 +271,7 @@ class TransportLayerClient(object):
|
|||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_leave(self, destination, room_id, event_id, content):
|
||||
path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id)
|
||||
path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
|
||||
|
||||
response = yield self.client.put_json(
|
||||
destination=destination,
|
||||
|
@ -290,7 +290,7 @@ class TransportLayerClient(object):
|
|||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_invite(self, destination, room_id, event_id, content):
|
||||
path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id)
|
||||
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
|
||||
|
||||
response = yield self.client.put_json(
|
||||
destination=destination,
|
||||
|
@ -306,7 +306,7 @@ class TransportLayerClient(object):
|
|||
def get_public_rooms(self, remote_server, limit, since_token,
|
||||
search_filter=None, include_all_networks=False,
|
||||
third_party_instance_id=None):
|
||||
path = PREFIX + "/publicRooms"
|
||||
path = _create_v1_path("/publicRooms")
|
||||
|
||||
args = {
|
||||
"include_all_networks": "true" if include_all_networks else "false",
|
||||
|
@ -332,7 +332,7 @@ class TransportLayerClient(object):
|
|||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def exchange_third_party_invite(self, destination, room_id, event_dict):
|
||||
path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,)
|
||||
path = _create_v1_path("/exchange_third_party_invite/%s", room_id,)
|
||||
|
||||
response = yield self.client.put_json(
|
||||
destination=destination,
|
||||
|
@ -345,7 +345,7 @@ class TransportLayerClient(object):
|
|||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_event_auth(self, destination, room_id, event_id):
|
||||
path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id)
|
||||
path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
|
||||
|
||||
content = yield self.client.get_json(
|
||||
destination=destination,
|
||||
|
@ -357,7 +357,7 @@ class TransportLayerClient(object):
|
|||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_query_auth(self, destination, room_id, event_id, content):
|
||||
path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id)
|
||||
path = _create_v1_path("/query_auth/%s/%s", room_id, event_id)
|
||||
|
||||
content = yield self.client.post_json(
|
||||
destination=destination,
|
||||
|
@ -392,7 +392,7 @@ class TransportLayerClient(object):
|
|||
Returns:
|
||||
A dict containg the device keys.
|
||||
"""
|
||||
path = PREFIX + "/user/keys/query"
|
||||
path = _create_v1_path("/user/keys/query")
|
||||
|
||||
content = yield self.client.post_json(
|
||||
destination=destination,
|
||||
|
@ -419,7 +419,7 @@ class TransportLayerClient(object):
|
|||
Returns:
|
||||
A dict containg the device keys.
|
||||
"""
|
||||
path = _create_path(PREFIX, "/user/devices/%s", user_id)
|
||||
path = _create_v1_path("/user/devices/%s", user_id)
|
||||
|
||||
content = yield self.client.get_json(
|
||||
destination=destination,
|
||||
|
@ -455,7 +455,7 @@ class TransportLayerClient(object):
|
|||
A dict containg the one-time keys.
|
||||
"""
|
||||
|
||||
path = PREFIX + "/user/keys/claim"
|
||||
path = _create_v1_path("/user/keys/claim")
|
||||
|
||||
content = yield self.client.post_json(
|
||||
destination=destination,
|
||||
|
@ -469,7 +469,7 @@ class TransportLayerClient(object):
|
|||
@log_function
|
||||
def get_missing_events(self, destination, room_id, earliest_events,
|
||||
latest_events, limit, min_depth, timeout):
|
||||
path = _create_path(PREFIX, "/get_missing_events/%s", room_id,)
|
||||
path = _create_v1_path("/get_missing_events/%s", room_id,)
|
||||
|
||||
content = yield self.client.post_json(
|
||||
destination=destination,
|
||||
|
@ -489,7 +489,7 @@ class TransportLayerClient(object):
|
|||
def get_group_profile(self, destination, group_id, requester_user_id):
|
||||
"""Get a group profile
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
|
||||
path = _create_v1_path("/groups/%s/profile", group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
|
@ -508,7 +508,7 @@ class TransportLayerClient(object):
|
|||
requester_user_id (str)
|
||||
content (dict): The new profile of the group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
|
||||
path = _create_v1_path("/groups/%s/profile", group_id,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
|
@ -522,7 +522,7 @@ class TransportLayerClient(object):
|
|||
def get_group_summary(self, destination, group_id, requester_user_id):
|
||||
"""Get a group summary
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/summary", group_id,)
|
||||
path = _create_v1_path("/groups/%s/summary", group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
|
@ -535,7 +535,7 @@ class TransportLayerClient(object):
|
|||
def get_rooms_in_group(self, destination, group_id, requester_user_id):
|
||||
"""Get all rooms in a group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/rooms", group_id,)
|
||||
path = _create_v1_path("/groups/%s/rooms", group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
|
@ -548,7 +548,7 @@ class TransportLayerClient(object):
|
|||
content):
|
||||
"""Add a room to a group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
|
||||
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
|
@ -562,8 +562,8 @@ class TransportLayerClient(object):
|
|||
config_key, content):
|
||||
"""Update room in group
|
||||
"""
|
||||
path = _create_path(
|
||||
PREFIX, "/groups/%s/room/%s/config/%s",
|
||||
path = _create_v1_path(
|
||||
"/groups/%s/room/%s/config/%s",
|
||||
group_id, room_id, config_key,
|
||||
)
|
||||
|
||||
|
@ -578,7 +578,7 @@ class TransportLayerClient(object):
|
|||
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
|
||||
"""Remove a room from a group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
|
||||
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
|
||||
|
||||
return self.client.delete_json(
|
||||
destination=destination,
|
||||
|
@ -591,7 +591,7 @@ class TransportLayerClient(object):
|
|||
def get_users_in_group(self, destination, group_id, requester_user_id):
|
||||
"""Get users in a group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/users", group_id,)
|
||||
path = _create_v1_path("/groups/%s/users", group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
|
@ -604,7 +604,7 @@ class TransportLayerClient(object):
|
|||
def get_invited_users_in_group(self, destination, group_id, requester_user_id):
|
||||
"""Get users that have been invited to a group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,)
|
||||
path = _create_v1_path("/groups/%s/invited_users", group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
|
@ -617,8 +617,8 @@ class TransportLayerClient(object):
|
|||
def accept_group_invite(self, destination, group_id, user_id, content):
|
||||
"""Accept a group invite
|
||||
"""
|
||||
path = _create_path(
|
||||
PREFIX, "/groups/%s/users/%s/accept_invite",
|
||||
path = _create_v1_path(
|
||||
"/groups/%s/users/%s/accept_invite",
|
||||
group_id, user_id,
|
||||
)
|
||||
|
||||
|
@ -633,7 +633,7 @@ class TransportLayerClient(object):
|
|||
def join_group(self, destination, group_id, user_id, content):
|
||||
"""Attempts to join a group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id)
|
||||
path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
|
@ -646,7 +646,7 @@ class TransportLayerClient(object):
|
|||
def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
|
||||
"""Invite a user to a group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id)
|
||||
path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
|
@ -662,7 +662,7 @@ class TransportLayerClient(object):
|
|||
invited.
|
||||
"""
|
||||
|
||||
path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id)
|
||||
path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
|
@ -676,7 +676,7 @@ class TransportLayerClient(object):
|
|||
user_id, content):
|
||||
"""Remove a user fron a group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id)
|
||||
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
|
@ -693,7 +693,7 @@ class TransportLayerClient(object):
|
|||
kicked from the group.
|
||||
"""
|
||||
|
||||
path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id)
|
||||
path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
|
@ -708,7 +708,7 @@ class TransportLayerClient(object):
|
|||
the attestations
|
||||
"""
|
||||
|
||||
path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id)
|
||||
path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
|
@ -723,12 +723,12 @@ class TransportLayerClient(object):
|
|||
"""Update a room entry in a group summary
|
||||
"""
|
||||
if category_id:
|
||||
path = _create_path(
|
||||
PREFIX, "/groups/%s/summary/categories/%s/rooms/%s",
|
||||
path = _create_v1_path(
|
||||
"/groups/%s/summary/categories/%s/rooms/%s",
|
||||
group_id, category_id, room_id,
|
||||
)
|
||||
else:
|
||||
path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
|
||||
path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
|
@ -744,12 +744,12 @@ class TransportLayerClient(object):
|
|||
"""Delete a room entry in a group summary
|
||||
"""
|
||||
if category_id:
|
||||
path = _create_path(
|
||||
PREFIX + "/groups/%s/summary/categories/%s/rooms/%s",
|
||||
path = _create_v1_path(
|
||||
"/groups/%s/summary/categories/%s/rooms/%s",
|
||||
group_id, category_id, room_id,
|
||||
)
|
||||
else:
|
||||
path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
|
||||
path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
|
||||
|
||||
return self.client.delete_json(
|
||||
destination=destination,
|
||||
|
@ -762,7 +762,7 @@ class TransportLayerClient(object):
|
|||
def get_group_categories(self, destination, group_id, requester_user_id):
|
||||
"""Get all categories in a group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/categories", group_id,)
|
||||
path = _create_v1_path("/groups/%s/categories", group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
|
@ -775,7 +775,7 @@ class TransportLayerClient(object):
|
|||
def get_group_category(self, destination, group_id, requester_user_id, category_id):
|
||||
"""Get category info in a group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
|
||||
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
|
@ -789,7 +789,7 @@ class TransportLayerClient(object):
|
|||
content):
|
||||
"""Update a category in a group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
|
||||
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
|
@ -804,7 +804,7 @@ class TransportLayerClient(object):
|
|||
category_id):
|
||||
"""Delete a category in a group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
|
||||
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
|
||||
|
||||
return self.client.delete_json(
|
||||
destination=destination,
|
||||
|
@ -817,7 +817,7 @@ class TransportLayerClient(object):
|
|||
def get_group_roles(self, destination, group_id, requester_user_id):
|
||||
"""Get all roles in a group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/roles", group_id,)
|
||||
path = _create_v1_path("/groups/%s/roles", group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
|
@ -830,7 +830,7 @@ class TransportLayerClient(object):
|
|||
def get_group_role(self, destination, group_id, requester_user_id, role_id):
|
||||
"""Get a roles info
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
|
||||
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
|
@ -844,7 +844,7 @@ class TransportLayerClient(object):
|
|||
content):
|
||||
"""Update a role in a group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
|
||||
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
|
@ -858,7 +858,7 @@ class TransportLayerClient(object):
|
|||
def delete_group_role(self, destination, group_id, requester_user_id, role_id):
|
||||
"""Delete a role in a group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
|
||||
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
|
||||
|
||||
return self.client.delete_json(
|
||||
destination=destination,
|
||||
|
@ -873,12 +873,12 @@ class TransportLayerClient(object):
|
|||
"""Update a users entry in a group
|
||||
"""
|
||||
if role_id:
|
||||
path = _create_path(
|
||||
PREFIX, "/groups/%s/summary/roles/%s/users/%s",
|
||||
path = _create_v1_path(
|
||||
"/groups/%s/summary/roles/%s/users/%s",
|
||||
group_id, role_id, user_id,
|
||||
)
|
||||
else:
|
||||
path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
|
||||
path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
|
@ -893,7 +893,7 @@ class TransportLayerClient(object):
|
|||
content):
|
||||
"""Sets the join policy for a group
|
||||
"""
|
||||
path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,)
|
||||
path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id,)
|
||||
|
||||
return self.client.put_json(
|
||||
destination=destination,
|
||||
|
@ -909,12 +909,12 @@ class TransportLayerClient(object):
|
|||
"""Delete a users entry in a group
|
||||
"""
|
||||
if role_id:
|
||||
path = _create_path(
|
||||
PREFIX, "/groups/%s/summary/roles/%s/users/%s",
|
||||
path = _create_v1_path(
|
||||
"/groups/%s/summary/roles/%s/users/%s",
|
||||
group_id, role_id, user_id,
|
||||
)
|
||||
else:
|
||||
path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
|
||||
path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
|
||||
|
||||
return self.client.delete_json(
|
||||
destination=destination,
|
||||
|
@ -927,7 +927,7 @@ class TransportLayerClient(object):
|
|||
"""Get the groups a list of users are publicising
|
||||
"""
|
||||
|
||||
path = PREFIX + "/get_groups_publicised"
|
||||
path = _create_v1_path("/get_groups_publicised")
|
||||
|
||||
content = {"user_ids": user_ids}
|
||||
|
||||
|
@ -939,20 +939,22 @@ class TransportLayerClient(object):
|
|||
)
|
||||
|
||||
|
||||
def _create_path(prefix, path, *args):
|
||||
"""Creates a path from the prefix, path template and args. Ensures that
|
||||
all args are url encoded.
|
||||
def _create_v1_path(path, *args):
|
||||
"""Creates a path against V1 federation API from the path template and
|
||||
args. Ensures that all args are url encoded.
|
||||
|
||||
Example:
|
||||
|
||||
_create_path(PREFIX, "/event/%s/", event_id)
|
||||
_create_v1_path("/event/%s/", event_id)
|
||||
|
||||
Args:
|
||||
prefix (str)
|
||||
path (str): String template for the path
|
||||
args: ([str]): Args to insert into path. Each arg will be url encoded
|
||||
|
||||
Returns:
|
||||
str
|
||||
"""
|
||||
return prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args)
|
||||
return (
|
||||
FEDERATION_V1_PREFIX
|
||||
+ path % tuple(urllib.parse.quote(arg, "") for arg in args)
|
||||
)
|
||||
|
|
|
@ -21,8 +21,9 @@ import re
|
|||
from twisted.internet import defer
|
||||
|
||||
import synapse
|
||||
from synapse.api.constants import RoomVersions
|
||||
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
|
||||
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
||||
from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
|
||||
from synapse.http.endpoint import parse_and_validate_server_name
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.http.servlet import (
|
||||
|
@ -227,6 +228,8 @@ class BaseFederationServlet(object):
|
|||
"""
|
||||
REQUIRE_AUTH = True
|
||||
|
||||
PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version
|
||||
|
||||
def __init__(self, handler, authenticator, ratelimiter, server_name):
|
||||
self.handler = handler
|
||||
self.authenticator = authenticator
|
||||
|
@ -286,7 +289,7 @@ class BaseFederationServlet(object):
|
|||
return new_func
|
||||
|
||||
def register(self, server):
|
||||
pattern = re.compile("^" + PREFIX + self.PATH + "$")
|
||||
pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
|
||||
|
||||
for method in ("GET", "PUT", "POST"):
|
||||
code = getattr(self, "on_%s" % (method), None)
|
||||
|
@ -488,14 +491,46 @@ class FederationSendJoinServlet(BaseFederationServlet):
|
|||
defer.returnValue((200, content))
|
||||
|
||||
|
||||
class FederationInviteServlet(BaseFederationServlet):
|
||||
class FederationV1InviteServlet(BaseFederationServlet):
|
||||
PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, origin, content, query, context, event_id):
|
||||
# We don't get a room version, so we have to assume its EITHER v1 or
|
||||
# v2. This is "fine" as the only difference between V1 and V2 is the
|
||||
# state resolution algorithm, and we don't use that for processing
|
||||
# invites
|
||||
content = yield self.handler.on_invite_request(
|
||||
origin, content, room_version=RoomVersions.V1,
|
||||
)
|
||||
|
||||
# V1 federation API is defined to return a content of `[200, {...}]`
|
||||
# due to a historical bug.
|
||||
defer.returnValue((200, (200, content)))
|
||||
|
||||
|
||||
class FederationV2InviteServlet(BaseFederationServlet):
|
||||
PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
|
||||
|
||||
PREFIX = FEDERATION_V2_PREFIX
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, origin, content, query, context, event_id):
|
||||
# TODO(paul): assert that context/event_id parsed from path actually
|
||||
# match those given in content
|
||||
content = yield self.handler.on_invite_request(origin, content)
|
||||
|
||||
room_version = content["room_version"]
|
||||
event = content["event"]
|
||||
invite_room_state = content["invite_room_state"]
|
||||
|
||||
# Synapse expects invite_room_state to be in unsigned, as it is in v1
|
||||
# API
|
||||
|
||||
event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state
|
||||
|
||||
content = yield self.handler.on_invite_request(
|
||||
origin, event, room_version=room_version,
|
||||
)
|
||||
defer.returnValue((200, content))
|
||||
|
||||
|
||||
|
@ -1263,7 +1298,8 @@ FEDERATION_SERVLET_CLASSES = (
|
|||
FederationEventServlet,
|
||||
FederationSendJoinServlet,
|
||||
FederationSendLeaveServlet,
|
||||
FederationInviteServlet,
|
||||
FederationV1InviteServlet,
|
||||
FederationV2InviteServlet,
|
||||
FederationQueryAuthServlet,
|
||||
FederationGetMissingEventsServlet,
|
||||
FederationEventAuthServlet,
|
||||
|
|
147
synapse/handlers/acme.py
Normal file
147
synapse/handlers/acme.py
Normal file
|
@ -0,0 +1,147 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import attr
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.endpoints import serverFromString
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.python.url import URL
|
||||
from twisted.web import server, static
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from txacme.interfaces import ICertificateStore
|
||||
|
||||
@attr.s
|
||||
@implementer(ICertificateStore)
|
||||
class ErsatzStore(object):
|
||||
"""
|
||||
A store that only stores in memory.
|
||||
"""
|
||||
|
||||
certs = attr.ib(default=attr.Factory(dict))
|
||||
|
||||
def store(self, server_name, pem_objects):
|
||||
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
|
||||
return defer.succeed(None)
|
||||
|
||||
|
||||
except ImportError:
|
||||
# txacme is missing
|
||||
pass
|
||||
|
||||
|
||||
class AcmeHandler(object):
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.reactor = hs.get_reactor()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start_listening(self):
|
||||
|
||||
# Configure logging for txacme, if you need to debug
|
||||
# from eliot import add_destinations
|
||||
# from eliot.twisted import TwistedDestination
|
||||
#
|
||||
# add_destinations(TwistedDestination())
|
||||
|
||||
from txacme.challenges import HTTP01Responder
|
||||
from txacme.service import AcmeIssuingService
|
||||
from txacme.endpoint import load_or_create_client_key
|
||||
from txacme.client import Client
|
||||
from josepy.jwa import RS256
|
||||
|
||||
self._store = ErsatzStore()
|
||||
responder = HTTP01Responder()
|
||||
|
||||
self._issuer = AcmeIssuingService(
|
||||
cert_store=self._store,
|
||||
client_creator=(
|
||||
lambda: Client.from_url(
|
||||
reactor=self.reactor,
|
||||
url=URL.from_text(self.hs.config.acme_url),
|
||||
key=load_or_create_client_key(
|
||||
FilePath(self.hs.config.config_dir_path)
|
||||
),
|
||||
alg=RS256,
|
||||
)
|
||||
),
|
||||
clock=self.reactor,
|
||||
responders=[responder],
|
||||
)
|
||||
|
||||
well_known = Resource()
|
||||
well_known.putChild(b'acme-challenge', responder.resource)
|
||||
responder_resource = Resource()
|
||||
responder_resource.putChild(b'.well-known', well_known)
|
||||
responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain'))
|
||||
|
||||
srv = server.Site(responder_resource)
|
||||
|
||||
listeners = []
|
||||
|
||||
for host in self.hs.config.acme_bind_addresses:
|
||||
logger.info(
|
||||
"Listening for ACME requests on %s:%s", host, self.hs.config.acme_port
|
||||
)
|
||||
endpoint = serverFromString(
|
||||
self.reactor, "tcp:%s:interface=%s" % (self.hs.config.acme_port, host)
|
||||
)
|
||||
listeners.append(endpoint.listen(srv))
|
||||
|
||||
# Make sure we are registered to the ACME server. There's no public API
|
||||
# for this, it is usually triggered by startService, but since we don't
|
||||
# want it to control where we save the certificates, we have to reach in
|
||||
# and trigger the registration machinery ourselves.
|
||||
self._issuer._registered = False
|
||||
yield self._issuer._ensure_registered()
|
||||
|
||||
# Return a Deferred that will fire when all the servers have started up.
|
||||
yield defer.DeferredList(listeners, fireOnOneErrback=True, consumeErrors=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def provision_certificate(self):
|
||||
|
||||
logger.warning("Reprovisioning %s", self.hs.hostname)
|
||||
|
||||
try:
|
||||
yield self._issuer.issue_cert(self.hs.hostname)
|
||||
except Exception:
|
||||
logger.exception("Fail!")
|
||||
raise
|
||||
logger.warning("Reprovisioned %s, saving.", self.hs.hostname)
|
||||
cert_chain = self._store.certs[self.hs.hostname]
|
||||
|
||||
try:
|
||||
with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:
|
||||
for x in cert_chain:
|
||||
if x.startswith(b"-----BEGIN RSA PRIVATE KEY-----"):
|
||||
private_key_file.write(x)
|
||||
|
||||
with open(self.hs.config.tls_certificate_file, "wb") as certificate_file:
|
||||
for x in cert_chain:
|
||||
if x.startswith(b"-----BEGIN CERTIFICATE-----"):
|
||||
certificate_file.write(x)
|
||||
except Exception:
|
||||
logger.exception("Failed saving!")
|
||||
raise
|
||||
|
||||
defer.returnValue(True)
|
|
@ -167,18 +167,21 @@ class IdentityHandler(BaseHandler):
|
|||
"mxid": mxid,
|
||||
"threepid": threepid,
|
||||
}
|
||||
headers = {}
|
||||
|
||||
# we abuse the federation http client to sign the request, but we have to send it
|
||||
# using the normal http client since we don't want the SRV lookup and want normal
|
||||
# 'browser-like' HTTPS.
|
||||
self.federation_http_client.sign_request(
|
||||
auth_headers = self.federation_http_client.build_auth_headers(
|
||||
destination=None,
|
||||
method='POST',
|
||||
url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'),
|
||||
headers_dict=headers,
|
||||
content=content,
|
||||
destination_is=id_server,
|
||||
)
|
||||
headers = {
|
||||
b"Authorization": auth_headers,
|
||||
}
|
||||
|
||||
try:
|
||||
yield self.http_client.post_json_get_json(
|
||||
url,
|
||||
|
|
|
@ -269,6 +269,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
(EventTypes.RoomHistoryVisibility, ""),
|
||||
(EventTypes.GuestAccess, ""),
|
||||
(EventTypes.RoomAvatar, ""),
|
||||
(EventTypes.Encryption, ""),
|
||||
)
|
||||
|
||||
old_room_state_ids = yield self.store.get_filtered_current_state_ids(
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from six import PY3, iteritems
|
||||
from six.moves import range
|
||||
|
@ -76,8 +77,14 @@ class RoomListHandler(BaseHandler):
|
|||
# We explicitly don't bother caching searches or requests for
|
||||
# appservice specific lists.
|
||||
logger.info("Bypassing cache as search request.")
|
||||
|
||||
# XXX: Quick hack to stop room directory queries taking too long.
|
||||
# Timeout request after 60s. Probably want a more fundamental
|
||||
# solution at some point
|
||||
timeout = datetime.now() + timedelta(seconds=60)
|
||||
return self._get_public_room_list(
|
||||
limit, since_token, search_filter, network_tuple=network_tuple,
|
||||
limit, since_token, search_filter,
|
||||
network_tuple=network_tuple, timeout=timeout,
|
||||
)
|
||||
|
||||
key = (limit, since_token, network_tuple)
|
||||
|
@ -90,7 +97,8 @@ class RoomListHandler(BaseHandler):
|
|||
@defer.inlineCallbacks
|
||||
def _get_public_room_list(self, limit=None, since_token=None,
|
||||
search_filter=None,
|
||||
network_tuple=EMPTY_THIRD_PARTY_ID,):
|
||||
network_tuple=EMPTY_THIRD_PARTY_ID,
|
||||
timeout=None,):
|
||||
if since_token and since_token != "END":
|
||||
since_token = RoomListNextBatch.from_token(since_token)
|
||||
else:
|
||||
|
@ -205,6 +213,9 @@ class RoomListHandler(BaseHandler):
|
|||
|
||||
chunk = []
|
||||
for i in range(0, len(rooms_to_scan), step):
|
||||
if timeout and datetime.now() > timeout:
|
||||
raise Exception("Timed out searching room directory")
|
||||
|
||||
batch = rooms_to_scan[i:i + step]
|
||||
logger.info("Processing %i rooms for result", len(batch))
|
||||
yield concurrently_execute(
|
||||
|
|
|
@ -20,6 +20,7 @@ from six import iteritems
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.metrics
|
||||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.roommember import ProfileInfo
|
||||
|
|
|
@ -12,30 +12,11 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||
from twisted.internet.error import ConnectError
|
||||
from twisted.names import client, dns
|
||||
from twisted.names.error import DNSNameError, DomainError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SERVER_CACHE = {}
|
||||
|
||||
# our record of an individual server which can be tried to reach a destination.
|
||||
#
|
||||
# "host" is the hostname acquired from the SRV record. Except when there's
|
||||
# no SRV record, in which case it is the original hostname.
|
||||
_Server = collections.namedtuple(
|
||||
"_Server", "priority weight host port expires"
|
||||
)
|
||||
|
||||
|
||||
def parse_server_name(server_name):
|
||||
"""Split a server name into host/port parts.
|
||||
|
@ -100,264 +81,3 @@ def parse_and_validate_server_name(server_name):
|
|||
))
|
||||
|
||||
return host, port
|
||||
|
||||
|
||||
def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
|
||||
timeout=None):
|
||||
"""Construct an endpoint for the given matrix destination.
|
||||
|
||||
Args:
|
||||
reactor: Twisted reactor.
|
||||
destination (unicode): The name of the server to connect to.
|
||||
tls_client_options_factory
|
||||
(synapse.crypto.context_factory.ClientTLSOptionsFactory):
|
||||
Factory which generates TLS options for client connections.
|
||||
timeout (int): connection timeout in seconds
|
||||
"""
|
||||
|
||||
domain, port = parse_server_name(destination)
|
||||
|
||||
endpoint_kw_args = {}
|
||||
|
||||
if timeout is not None:
|
||||
endpoint_kw_args.update(timeout=timeout)
|
||||
|
||||
if tls_client_options_factory is None:
|
||||
transport_endpoint = HostnameEndpoint
|
||||
default_port = 8008
|
||||
else:
|
||||
# the SNI string should be the same as the Host header, minus the port.
|
||||
# as per https://github.com/matrix-org/synapse/issues/2525#issuecomment-336896777,
|
||||
# the Host header and SNI should therefore be the server_name of the remote
|
||||
# server.
|
||||
tls_options = tls_client_options_factory.get_options(domain)
|
||||
|
||||
def transport_endpoint(reactor, host, port, timeout):
|
||||
return wrapClientTLS(
|
||||
tls_options,
|
||||
HostnameEndpoint(reactor, host, port, timeout=timeout),
|
||||
)
|
||||
default_port = 8448
|
||||
|
||||
if port is None:
|
||||
return _WrappingEndpointFac(SRVClientEndpoint(
|
||||
reactor, "matrix", domain, protocol="tcp",
|
||||
default_port=default_port, endpoint=transport_endpoint,
|
||||
endpoint_kw_args=endpoint_kw_args
|
||||
), reactor)
|
||||
else:
|
||||
return _WrappingEndpointFac(transport_endpoint(
|
||||
reactor, domain, port, **endpoint_kw_args
|
||||
), reactor)
|
||||
|
||||
|
||||
class _WrappingEndpointFac(object):
|
||||
def __init__(self, endpoint_fac, reactor):
|
||||
self.endpoint_fac = endpoint_fac
|
||||
self.reactor = reactor
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def connect(self, protocolFactory):
|
||||
conn = yield self.endpoint_fac.connect(protocolFactory)
|
||||
conn = _WrappedConnection(conn, self.reactor)
|
||||
defer.returnValue(conn)
|
||||
|
||||
|
||||
class _WrappedConnection(object):
|
||||
"""Wraps a connection and calls abort on it if it hasn't seen any action
|
||||
for 2.5-3 minutes.
|
||||
"""
|
||||
__slots__ = ["conn", "last_request"]
|
||||
|
||||
def __init__(self, conn, reactor):
|
||||
object.__setattr__(self, "conn", conn)
|
||||
object.__setattr__(self, "last_request", time.time())
|
||||
self._reactor = reactor
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.conn, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
setattr(self.conn, name, value)
|
||||
|
||||
def _time_things_out_maybe(self):
|
||||
# We use a slightly shorter timeout here just in case the callLater is
|
||||
# triggered early. Paranoia ftw.
|
||||
# TODO: Cancel the previous callLater rather than comparing time.time()?
|
||||
if time.time() - self.last_request >= 2.5 * 60:
|
||||
self.abort()
|
||||
# Abort the underlying TLS connection. The abort() method calls
|
||||
# loseConnection() on the TLS connection which tries to
|
||||
# shutdown the connection cleanly. We call abortConnection()
|
||||
# since that will promptly close the TLS connection.
|
||||
#
|
||||
# In Twisted >18.4; the TLS connection will be None if it has closed
|
||||
# which will make abortConnection() throw. Check that the TLS connection
|
||||
# is not None before trying to close it.
|
||||
if self.transport.getHandle() is not None:
|
||||
self.transport.abortConnection()
|
||||
|
||||
def request(self, request):
|
||||
self.last_request = time.time()
|
||||
|
||||
# Time this connection out if we haven't send a request in the last
|
||||
# N minutes
|
||||
# TODO: Cancel the previous callLater?
|
||||
self._reactor.callLater(3 * 60, self._time_things_out_maybe)
|
||||
|
||||
d = self.conn.request(request)
|
||||
|
||||
def update_request_time(res):
|
||||
self.last_request = time.time()
|
||||
# TODO: Cancel the previous callLater?
|
||||
self._reactor.callLater(3 * 60, self._time_things_out_maybe)
|
||||
return res
|
||||
|
||||
d.addCallback(update_request_time)
|
||||
|
||||
return d
|
||||
|
||||
|
||||
class SRVClientEndpoint(object):
|
||||
"""An endpoint which looks up SRV records for a service.
|
||||
Cycles through the list of servers starting with each call to connect
|
||||
picking the next server.
|
||||
Implements twisted.internet.interfaces.IStreamClientEndpoint.
|
||||
"""
|
||||
|
||||
def __init__(self, reactor, service, domain, protocol="tcp",
|
||||
default_port=None, endpoint=HostnameEndpoint,
|
||||
endpoint_kw_args={}):
|
||||
self.reactor = reactor
|
||||
self.service_name = "_%s._%s.%s" % (service, protocol, domain)
|
||||
|
||||
if default_port is not None:
|
||||
self.default_server = _Server(
|
||||
host=domain,
|
||||
port=default_port,
|
||||
priority=0,
|
||||
weight=0,
|
||||
expires=0,
|
||||
)
|
||||
else:
|
||||
self.default_server = None
|
||||
|
||||
self.endpoint = endpoint
|
||||
self.endpoint_kw_args = endpoint_kw_args
|
||||
|
||||
self.servers = None
|
||||
self.used_servers = None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fetch_servers(self):
|
||||
self.used_servers = []
|
||||
self.servers = yield resolve_service(self.service_name)
|
||||
|
||||
def pick_server(self):
|
||||
if not self.servers:
|
||||
if self.used_servers:
|
||||
self.servers = self.used_servers
|
||||
self.used_servers = []
|
||||
self.servers.sort()
|
||||
elif self.default_server:
|
||||
return self.default_server
|
||||
else:
|
||||
raise ConnectError(
|
||||
"No server available for %s" % self.service_name
|
||||
)
|
||||
|
||||
# look for all servers with the same priority
|
||||
min_priority = self.servers[0].priority
|
||||
weight_indexes = list(
|
||||
(index, server.weight + 1)
|
||||
for index, server in enumerate(self.servers)
|
||||
if server.priority == min_priority
|
||||
)
|
||||
|
||||
total_weight = sum(weight for index, weight in weight_indexes)
|
||||
target_weight = random.randint(0, total_weight)
|
||||
for index, weight in weight_indexes:
|
||||
target_weight -= weight
|
||||
if target_weight <= 0:
|
||||
server = self.servers[index]
|
||||
# XXX: this looks totally dubious:
|
||||
#
|
||||
# (a) we never reuse a server until we have been through
|
||||
# all of the servers at the same priority, so if the
|
||||
# weights are A: 100, B:1, we always do ABABAB instead of
|
||||
# AAAA...AAAB (approximately).
|
||||
#
|
||||
# (b) After using all the servers at the lowest priority,
|
||||
# we move onto the next priority. We should only use the
|
||||
# second priority if servers at the top priority are
|
||||
# unreachable.
|
||||
#
|
||||
del self.servers[index]
|
||||
self.used_servers.append(server)
|
||||
return server
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def connect(self, protocolFactory):
|
||||
if self.servers is None:
|
||||
yield self.fetch_servers()
|
||||
server = self.pick_server()
|
||||
logger.info("Connecting to %s:%s", server.host, server.port)
|
||||
endpoint = self.endpoint(
|
||||
self.reactor, server.host, server.port, **self.endpoint_kw_args
|
||||
)
|
||||
connection = yield endpoint.connect(protocolFactory)
|
||||
defer.returnValue(connection)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
|
||||
cache_entry = cache.get(service_name, None)
|
||||
if cache_entry:
|
||||
if all(s.expires > int(clock.time()) for s in cache_entry):
|
||||
servers = list(cache_entry)
|
||||
defer.returnValue(servers)
|
||||
|
||||
servers = []
|
||||
|
||||
try:
|
||||
try:
|
||||
answers, _, _ = yield dns_client.lookupService(service_name)
|
||||
except DNSNameError:
|
||||
defer.returnValue([])
|
||||
|
||||
if (len(answers) == 1
|
||||
and answers[0].type == dns.SRV
|
||||
and answers[0].payload
|
||||
and answers[0].payload.target == dns.Name(b'.')):
|
||||
raise ConnectError("Service %s unavailable" % service_name)
|
||||
|
||||
for answer in answers:
|
||||
if answer.type != dns.SRV or not answer.payload:
|
||||
continue
|
||||
|
||||
payload = answer.payload
|
||||
|
||||
servers.append(_Server(
|
||||
host=str(payload.target),
|
||||
port=int(payload.port),
|
||||
priority=int(payload.priority),
|
||||
weight=int(payload.weight),
|
||||
expires=int(clock.time()) + answer.ttl,
|
||||
))
|
||||
|
||||
servers.sort()
|
||||
cache[service_name] = list(servers)
|
||||
except DomainError as e:
|
||||
# We failed to resolve the name (other than a NameError)
|
||||
# Try something in the cache, else rereaise
|
||||
cache_entry = cache.get(service_name, None)
|
||||
if cache_entry:
|
||||
logger.warn(
|
||||
"Failed to resolve %r, falling back to cache. %r",
|
||||
service_name, e
|
||||
)
|
||||
servers = list(cache_entry)
|
||||
else:
|
||||
raise e
|
||||
|
||||
defer.returnValue(servers)
|
||||
|
|
14
synapse/http/federation/__init__.py
Normal file
14
synapse/http/federation/__init__.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
124
synapse/http/federation/matrix_federation_agent.py
Normal file
124
synapse/http/federation/matrix_federation_agent.py
Normal file
|
@ -0,0 +1,124 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||
from twisted.web.client import URI, Agent, HTTPConnectionPool
|
||||
from twisted.web.iweb import IAgent
|
||||
|
||||
from synapse.http.endpoint import parse_server_name
|
||||
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@implementer(IAgent)
|
||||
class MatrixFederationAgent(object):
|
||||
"""An Agent-like thing which provides a `request` method which will look up a matrix
|
||||
server and send an HTTP request to it.
|
||||
|
||||
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
|
||||
|
||||
Args:
|
||||
reactor (IReactor): twisted reactor to use for underlying requests
|
||||
|
||||
tls_client_options_factory (ClientTLSOptionsFactory|None):
|
||||
factory to use for fetching client tls options, or none to disable TLS.
|
||||
|
||||
srv_resolver (SrvResolver|None):
|
||||
SRVResolver impl to use for looking up SRV records. None to use a default
|
||||
implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, reactor, tls_client_options_factory, _srv_resolver=None,
|
||||
):
|
||||
self._reactor = reactor
|
||||
self._tls_client_options_factory = tls_client_options_factory
|
||||
if _srv_resolver is None:
|
||||
_srv_resolver = SrvResolver()
|
||||
self._srv_resolver = _srv_resolver
|
||||
|
||||
self._pool = HTTPConnectionPool(reactor)
|
||||
self._pool.retryAutomatically = False
|
||||
self._pool.maxPersistentPerHost = 5
|
||||
self._pool.cachedConnectionTimeout = 2 * 60
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def request(self, method, uri, headers=None, bodyProducer=None):
|
||||
"""
|
||||
Args:
|
||||
method (bytes): HTTP method: GET/POST/etc
|
||||
|
||||
uri (bytes): Absolute URI to be retrieved
|
||||
|
||||
headers (twisted.web.http_headers.Headers|None):
|
||||
HTTP headers to send with the request, or None to
|
||||
send no extra headers.
|
||||
|
||||
bodyProducer (twisted.web.iweb.IBodyProducer|None):
|
||||
An object which can generate bytes to make up the
|
||||
body of this request (for example, the properly encoded contents of
|
||||
a file for a file upload). Or None if the request is to have
|
||||
no body.
|
||||
|
||||
Returns:
|
||||
Deferred[twisted.web.iweb.IResponse]:
|
||||
fires when the header of the response has been received (regardless of the
|
||||
response status code). Fails if there is any problem which prevents that
|
||||
response from being received (including problems that prevent the request
|
||||
from being sent).
|
||||
"""
|
||||
|
||||
parsed_uri = URI.fromBytes(uri)
|
||||
server_name_bytes = parsed_uri.netloc
|
||||
host, port = parse_server_name(server_name_bytes.decode("ascii"))
|
||||
|
||||
# XXX disabling TLS is really only supported here for the benefit of the
|
||||
# unit tests. We should make the UTs cope with TLS rather than having to make
|
||||
# the code support the unit tests.
|
||||
if self._tls_client_options_factory is None:
|
||||
tls_options = None
|
||||
else:
|
||||
tls_options = self._tls_client_options_factory.get_options(host)
|
||||
|
||||
if port is not None:
|
||||
target = (host, port)
|
||||
else:
|
||||
server_list = yield self._srv_resolver.resolve_service(server_name_bytes)
|
||||
if not server_list:
|
||||
target = (host, 8448)
|
||||
logger.debug("No SRV record for %s, using %s", host, target)
|
||||
else:
|
||||
target = pick_server_from_list(server_list)
|
||||
|
||||
class EndpointFactory(object):
|
||||
@staticmethod
|
||||
def endpointForURI(_uri):
|
||||
logger.info("Connecting to %s:%s", target[0], target[1])
|
||||
ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1])
|
||||
if tls_options is not None:
|
||||
ep = wrapClientTLS(tls_options, ep)
|
||||
return ep
|
||||
|
||||
agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
|
||||
res = yield make_deferred_yieldable(
|
||||
agent.request(method, uri, headers, bodyProducer)
|
||||
)
|
||||
defer.returnValue(res)
|
169
synapse/http/federation/srv_resolver.py
Normal file
169
synapse/http/federation/srv_resolver.py
Normal file
|
@ -0,0 +1,169 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.error import ConnectError
|
||||
from twisted.names import client, dns
|
||||
from twisted.names.error import DNSNameError, DomainError
|
||||
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SERVER_CACHE = {}
|
||||
|
||||
|
||||
@attr.s
|
||||
class Server(object):
|
||||
"""
|
||||
Our record of an individual server which can be tried to reach a destination.
|
||||
|
||||
Attributes:
|
||||
host (bytes): target hostname
|
||||
port (int):
|
||||
priority (int):
|
||||
weight (int):
|
||||
expires (int): when the cache should expire this record - in *seconds* since
|
||||
the epoch
|
||||
"""
|
||||
host = attr.ib()
|
||||
port = attr.ib()
|
||||
priority = attr.ib(default=0)
|
||||
weight = attr.ib(default=0)
|
||||
expires = attr.ib(default=0)
|
||||
|
||||
|
||||
def pick_server_from_list(server_list):
|
||||
"""Randomly choose a server from the server list
|
||||
|
||||
Args:
|
||||
server_list (list[Server]): list of candidate servers
|
||||
|
||||
Returns:
|
||||
Tuple[bytes, int]: (host, port) pair for the chosen server
|
||||
"""
|
||||
if not server_list:
|
||||
raise RuntimeError("pick_server_from_list called with empty list")
|
||||
|
||||
# TODO: currently we only use the lowest-priority servers. We should maintain a
|
||||
# cache of servers known to be "down" and filter them out
|
||||
|
||||
min_priority = min(s.priority for s in server_list)
|
||||
eligible_servers = list(s for s in server_list if s.priority == min_priority)
|
||||
total_weight = sum(s.weight for s in eligible_servers)
|
||||
target_weight = random.randint(0, total_weight)
|
||||
|
||||
for s in eligible_servers:
|
||||
target_weight -= s.weight
|
||||
|
||||
if target_weight <= 0:
|
||||
return s.host, s.port
|
||||
|
||||
# this should be impossible.
|
||||
raise RuntimeError(
|
||||
"pick_server_from_list got to end of eligible server list.",
|
||||
)
|
||||
|
||||
|
||||
class SrvResolver(object):
|
||||
"""Interface to the dns client to do SRV lookups, with result caching.
|
||||
|
||||
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
|
||||
but the cache never gets populated), so we add our own caching layer here.
|
||||
|
||||
Args:
|
||||
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
|
||||
cache (dict): cache object
|
||||
get_time (callable): clock implementation. Should return seconds since the epoch
|
||||
"""
|
||||
def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
|
||||
self._dns_client = dns_client
|
||||
self._cache = cache
|
||||
self._get_time = get_time
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_service(self, service_name):
|
||||
"""Look up a SRV record
|
||||
|
||||
Args:
|
||||
service_name (bytes): record to look up
|
||||
|
||||
Returns:
|
||||
Deferred[list[Server]]:
|
||||
a list of the SRV records, or an empty list if none found
|
||||
"""
|
||||
now = int(self._get_time())
|
||||
|
||||
if not isinstance(service_name, bytes):
|
||||
raise TypeError("%r is not a byte string" % (service_name,))
|
||||
|
||||
cache_entry = self._cache.get(service_name, None)
|
||||
if cache_entry:
|
||||
if all(s.expires > now for s in cache_entry):
|
||||
servers = list(cache_entry)
|
||||
defer.returnValue(servers)
|
||||
|
||||
try:
|
||||
answers, _, _ = yield make_deferred_yieldable(
|
||||
self._dns_client.lookupService(service_name),
|
||||
)
|
||||
except DNSNameError:
|
||||
# TODO: cache this. We can get the SOA out of the exception, and use
|
||||
# the negative-TTL value.
|
||||
defer.returnValue([])
|
||||
except DomainError as e:
|
||||
# We failed to resolve the name (other than a NameError)
|
||||
# Try something in the cache, else rereaise
|
||||
cache_entry = self._cache.get(service_name, None)
|
||||
if cache_entry:
|
||||
logger.warn(
|
||||
"Failed to resolve %r, falling back to cache. %r",
|
||||
service_name, e
|
||||
)
|
||||
defer.returnValue(list(cache_entry))
|
||||
else:
|
||||
raise e
|
||||
|
||||
if (len(answers) == 1
|
||||
and answers[0].type == dns.SRV
|
||||
and answers[0].payload
|
||||
and answers[0].payload.target == dns.Name(b'.')):
|
||||
raise ConnectError("Service %s unavailable" % service_name)
|
||||
|
||||
servers = []
|
||||
|
||||
for answer in answers:
|
||||
if answer.type != dns.SRV or not answer.payload:
|
||||
continue
|
||||
|
||||
payload = answer.payload
|
||||
|
||||
servers.append(Server(
|
||||
host=payload.target.name,
|
||||
port=payload.port,
|
||||
priority=payload.priority,
|
||||
weight=payload.weight,
|
||||
expires=now + answer.ttl,
|
||||
))
|
||||
|
||||
self._cache[service_name] = list(servers)
|
||||
defer.returnValue(servers)
|
|
@ -32,7 +32,7 @@ from twisted.internet import defer, protocol
|
|||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.internet.task import _EPSILON, Cooperator
|
||||
from twisted.web._newclient import ResponseDone
|
||||
from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool
|
||||
from twisted.web.client import FileBodyProducer
|
||||
from twisted.web.http_headers import Headers
|
||||
|
||||
import synapse.metrics
|
||||
|
@ -44,7 +44,7 @@ from synapse.api.errors import (
|
|||
RequestSendFailed,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.http.endpoint import matrix_federation_endpoint
|
||||
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
|
||||
from synapse.util.async_helpers import timeout_deferred
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
from synapse.util.metrics import Measure
|
||||
|
@ -66,20 +66,6 @@ else:
|
|||
MAXINT = sys.maxint
|
||||
|
||||
|
||||
class MatrixFederationEndpointFactory(object):
|
||||
def __init__(self, hs):
|
||||
self.reactor = hs.get_reactor()
|
||||
self.tls_client_options_factory = hs.tls_client_options_factory
|
||||
|
||||
def endpointForURI(self, uri):
|
||||
destination = uri.netloc.decode('ascii')
|
||||
|
||||
return matrix_federation_endpoint(
|
||||
self.reactor, destination, timeout=10,
|
||||
tls_client_options_factory=self.tls_client_options_factory
|
||||
)
|
||||
|
||||
|
||||
_next_id = 1
|
||||
|
||||
|
||||
|
@ -187,12 +173,10 @@ class MatrixFederationHttpClient(object):
|
|||
self.signing_key = hs.config.signing_key[0]
|
||||
self.server_name = hs.hostname
|
||||
reactor = hs.get_reactor()
|
||||
pool = HTTPConnectionPool(reactor)
|
||||
pool.retryAutomatically = False
|
||||
pool.maxPersistentPerHost = 5
|
||||
pool.cachedConnectionTimeout = 2 * 60
|
||||
self.agent = Agent.usingEndpointFactory(
|
||||
reactor, MatrixFederationEndpointFactory(hs), pool=pool
|
||||
|
||||
self.agent = MatrixFederationAgent(
|
||||
hs.get_reactor(),
|
||||
hs.tls_client_options_factory,
|
||||
)
|
||||
self.clock = hs.get_clock()
|
||||
self._store = hs.get_datastore()
|
||||
|
@ -229,19 +213,18 @@ class MatrixFederationHttpClient(object):
|
|||
backoff_on_404 (bool): Back off if we get a 404
|
||||
|
||||
Returns:
|
||||
Deferred: resolves with the http response object on success.
|
||||
Deferred[twisted.web.client.Response]: resolves with the HTTP
|
||||
response object on success.
|
||||
|
||||
Fails with ``HttpResponseException``: if we get an HTTP response
|
||||
code >= 300 (except 429).
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
|
||||
Fails with ``RequestSendFailed`` if there were problems connecting to
|
||||
the remote, due to e.g. DNS failures, connection timeouts etc.
|
||||
Raises:
|
||||
HttpResponseException: If we get an HTTP response code >= 300
|
||||
(except 429).
|
||||
NotRetryingDestination: If we are not yet ready to retry this
|
||||
server.
|
||||
FederationDeniedError: If this destination is not on our
|
||||
federation whitelist
|
||||
RequestSendFailed: If there were problems connecting to the
|
||||
remote, due to e.g. DNS failures, connection timeouts etc.
|
||||
"""
|
||||
if timeout:
|
||||
_sec_timeout = timeout / 1000
|
||||
|
@ -299,9 +282,9 @@ class MatrixFederationHttpClient(object):
|
|||
json = request.get_json()
|
||||
if json:
|
||||
headers_dict[b"Content-Type"] = [b"application/json"]
|
||||
self.sign_request(
|
||||
auth_headers = self.build_auth_headers(
|
||||
destination_bytes, method_bytes, url_to_sign_bytes,
|
||||
headers_dict, json,
|
||||
json,
|
||||
)
|
||||
data = encode_canonical_json(json)
|
||||
producer = FileBodyProducer(
|
||||
|
@ -310,40 +293,40 @@ class MatrixFederationHttpClient(object):
|
|||
)
|
||||
else:
|
||||
producer = None
|
||||
self.sign_request(
|
||||
auth_headers = self.build_auth_headers(
|
||||
destination_bytes, method_bytes, url_to_sign_bytes,
|
||||
headers_dict,
|
||||
)
|
||||
|
||||
headers_dict[b"Authorization"] = auth_headers
|
||||
|
||||
logger.info(
|
||||
"{%s} [%s] Sending request: %s %s",
|
||||
"{%s} [%s] Sending request: %s %s; timeout %fs",
|
||||
request.txn_id, request.destination, request.method,
|
||||
url_str,
|
||||
)
|
||||
|
||||
# we don't want all the fancy cookie and redirect handling that
|
||||
# treq.request gives: just use the raw Agent.
|
||||
request_deferred = self.agent.request(
|
||||
method_bytes,
|
||||
url_bytes,
|
||||
headers=Headers(headers_dict),
|
||||
bodyProducer=producer,
|
||||
)
|
||||
|
||||
request_deferred = timeout_deferred(
|
||||
request_deferred,
|
||||
timeout=_sec_timeout,
|
||||
reactor=self.hs.get_reactor(),
|
||||
url_str, _sec_timeout,
|
||||
)
|
||||
|
||||
try:
|
||||
with Measure(self.clock, "outbound_request"):
|
||||
response = yield make_deferred_yieldable(
|
||||
request_deferred,
|
||||
# we don't want all the fancy cookie and redirect handling
|
||||
# that treq.request gives: just use the raw Agent.
|
||||
request_deferred = self.agent.request(
|
||||
method_bytes,
|
||||
url_bytes,
|
||||
headers=Headers(headers_dict),
|
||||
bodyProducer=producer,
|
||||
)
|
||||
|
||||
request_deferred = timeout_deferred(
|
||||
request_deferred,
|
||||
timeout=_sec_timeout,
|
||||
reactor=self.hs.get_reactor(),
|
||||
)
|
||||
|
||||
response = yield request_deferred
|
||||
except DNSLookupError as e:
|
||||
raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e)
|
||||
except Exception as e:
|
||||
logger.info("Failed to send request: %s", e)
|
||||
raise_from(RequestSendFailed(e, can_retry=True), e)
|
||||
|
||||
logger.info(
|
||||
|
@ -441,24 +424,23 @@ class MatrixFederationHttpClient(object):
|
|||
|
||||
defer.returnValue(response)
|
||||
|
||||
def sign_request(self, destination, method, url_bytes, headers_dict,
|
||||
content=None, destination_is=None):
|
||||
def build_auth_headers(
|
||||
self, destination, method, url_bytes, content=None, destination_is=None,
|
||||
):
|
||||
"""
|
||||
Signs a request by adding an Authorization header to headers_dict
|
||||
Builds the Authorization headers for a federation request
|
||||
Args:
|
||||
destination (bytes|None): The desination home server of the request.
|
||||
May be None if the destination is an identity server, in which case
|
||||
destination_is must be non-None.
|
||||
method (bytes): The HTTP method of the request
|
||||
url_bytes (bytes): The URI path of the request
|
||||
headers_dict (dict[bytes, list[bytes]]): Dictionary of request headers to
|
||||
append to
|
||||
content (object): The body of the request
|
||||
destination_is (bytes): As 'destination', but if the destination is an
|
||||
identity server
|
||||
|
||||
Returns:
|
||||
None
|
||||
list[bytes]: a list of headers to be added as "Authorization:" headers
|
||||
"""
|
||||
request = {
|
||||
"method": method,
|
||||
|
@ -485,8 +467,7 @@ class MatrixFederationHttpClient(object):
|
|||
self.server_name, key, sig,
|
||||
)).encode('ascii')
|
||||
)
|
||||
|
||||
headers_dict[b"Authorization"] = auth_headers
|
||||
return auth_headers
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def put_json(self, destination, path, args={}, data={},
|
||||
|
@ -516,17 +497,18 @@ class MatrixFederationHttpClient(object):
|
|||
requests)
|
||||
|
||||
Returns:
|
||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||
will be the decoded JSON body.
|
||||
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
|
||||
result will be the decoded JSON body.
|
||||
|
||||
Fails with ``HttpResponseException`` if we get an HTTP response
|
||||
code >= 300.
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
Raises:
|
||||
HttpResponseException: If we get an HTTP response code >= 300
|
||||
(except 429).
|
||||
NotRetryingDestination: If we are not yet ready to retry this
|
||||
server.
|
||||
FederationDeniedError: If this destination is not on our
|
||||
federation whitelist
|
||||
RequestSendFailed: If there were problems connecting to the
|
||||
remote, due to e.g. DNS failures, connection timeouts etc.
|
||||
"""
|
||||
|
||||
request = MatrixFederationRequest(
|
||||
|
@ -570,17 +552,18 @@ class MatrixFederationHttpClient(object):
|
|||
try the request anyway.
|
||||
args (dict): query params
|
||||
Returns:
|
||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||
will be the decoded JSON body.
|
||||
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
|
||||
result will be the decoded JSON body.
|
||||
|
||||
Fails with ``HttpResponseException`` if we get an HTTP response
|
||||
code >= 300.
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
Raises:
|
||||
HttpResponseException: If we get an HTTP response code >= 300
|
||||
(except 429).
|
||||
NotRetryingDestination: If we are not yet ready to retry this
|
||||
server.
|
||||
FederationDeniedError: If this destination is not on our
|
||||
federation whitelist
|
||||
RequestSendFailed: If there were problems connecting to the
|
||||
remote, due to e.g. DNS failures, connection timeouts etc.
|
||||
"""
|
||||
|
||||
request = MatrixFederationRequest(
|
||||
|
@ -625,17 +608,18 @@ class MatrixFederationHttpClient(object):
|
|||
ignore_backoff (bool): true to ignore the historical backoff data
|
||||
and try the request anyway.
|
||||
Returns:
|
||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||
will be the decoded JSON body.
|
||||
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
|
||||
result will be the decoded JSON body.
|
||||
|
||||
Fails with ``HttpResponseException`` if we get an HTTP response
|
||||
code >= 300.
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
Raises:
|
||||
HttpResponseException: If we get an HTTP response code >= 300
|
||||
(except 429).
|
||||
NotRetryingDestination: If we are not yet ready to retry this
|
||||
server.
|
||||
FederationDeniedError: If this destination is not on our
|
||||
federation whitelist
|
||||
RequestSendFailed: If there were problems connecting to the
|
||||
remote, due to e.g. DNS failures, connection timeouts etc.
|
||||
"""
|
||||
logger.debug("get_json args: %s", args)
|
||||
|
||||
|
@ -676,17 +660,18 @@ class MatrixFederationHttpClient(object):
|
|||
ignore_backoff (bool): true to ignore the historical backoff data and
|
||||
try the request anyway.
|
||||
Returns:
|
||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||
will be the decoded JSON body.
|
||||
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
|
||||
result will be the decoded JSON body.
|
||||
|
||||
Fails with ``HttpResponseException`` if we get an HTTP response
|
||||
code >= 300.
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
Raises:
|
||||
HttpResponseException: If we get an HTTP response code >= 300
|
||||
(except 429).
|
||||
NotRetryingDestination: If we are not yet ready to retry this
|
||||
server.
|
||||
FederationDeniedError: If this destination is not on our
|
||||
federation whitelist
|
||||
RequestSendFailed: If there were problems connecting to the
|
||||
remote, due to e.g. DNS failures, connection timeouts etc.
|
||||
"""
|
||||
request = MatrixFederationRequest(
|
||||
method="DELETE",
|
||||
|
@ -719,18 +704,20 @@ class MatrixFederationHttpClient(object):
|
|||
args (dict): Optional dictionary used to create the query string.
|
||||
ignore_backoff (bool): true to ignore the historical backoff data
|
||||
and try the request anyway.
|
||||
|
||||
Returns:
|
||||
Deferred: resolves with an (int,dict) tuple of the file length and
|
||||
a dict of the response headers.
|
||||
Deferred[tuple[int, dict]]: Resolves with an (int,dict) tuple of
|
||||
the file length and a dict of the response headers.
|
||||
|
||||
Fails with ``HttpResponseException`` if we get an HTTP response code
|
||||
>= 300
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
Raises:
|
||||
HttpResponseException: If we get an HTTP response code >= 300
|
||||
(except 429).
|
||||
NotRetryingDestination: If we are not yet ready to retry this
|
||||
server.
|
||||
FederationDeniedError: If this destination is not on our
|
||||
federation whitelist
|
||||
RequestSendFailed: If there were problems connecting to the
|
||||
remote, due to e.g. DNS failures, connection timeouts etc.
|
||||
"""
|
||||
request = MatrixFederationRequest(
|
||||
method="GET",
|
||||
|
|
|
@ -40,7 +40,11 @@ REQUIREMENTS = [
|
|||
"signedjson>=1.0.0",
|
||||
"pynacl>=1.2.1",
|
||||
"service_identity>=16.0.0",
|
||||
"Twisted>=17.1.0",
|
||||
|
||||
# our logcontext handling relies on the ability to cancel inlineCallbacks
|
||||
# (https://twistedmatrix.com/trac/ticket/4632) which landed in Twisted 18.7.
|
||||
"Twisted>=18.7.0",
|
||||
|
||||
"treq>=15.1",
|
||||
# Twisted has required pyopenssl 16.0 since about Twisted 16.6.
|
||||
"pyopenssl>=16.0.0",
|
||||
|
@ -52,15 +56,18 @@ REQUIREMENTS = [
|
|||
"pillow>=3.1.2",
|
||||
"sortedcontainers>=1.4.4",
|
||||
"psutil>=2.0.0",
|
||||
"pymacaroons-pynacl>=0.9.3",
|
||||
"msgpack-python>=0.4.2",
|
||||
"pymacaroons>=0.13.0",
|
||||
"msgpack>=0.5.0",
|
||||
"phonenumbers>=8.2.0",
|
||||
"six>=1.10",
|
||||
# prometheus_client 0.4.0 changed the format of counter metrics
|
||||
# (cf https://github.com/matrix-org/synapse/issues/4001)
|
||||
"prometheus_client>=0.0.18,<0.4.0",
|
||||
|
||||
# we use attr.s(slots), which arrived in 16.0.0
|
||||
"attrs>=16.0.0",
|
||||
# Twisted 18.7.0 requires attrs>=17.4.0
|
||||
"attrs>=17.4.0",
|
||||
|
||||
"netaddr>=0.7.18",
|
||||
]
|
||||
|
||||
|
@ -72,6 +79,10 @@ CONDITIONAL_REQUIREMENTS = {
|
|||
# ConsentResource uses select_autoescape, which arrived in jinja 2.9
|
||||
"resources.consent": ["Jinja2>=2.9"],
|
||||
|
||||
# ACME support is required to provision TLS certificates from authorities
|
||||
# that use the protocol, such as Let's Encrypt.
|
||||
"acme": ["txacme>=0.9.2"],
|
||||
|
||||
"saml2": ["pysaml2>=4.5.0"],
|
||||
"url_preview": ["lxml>=3.5.0"],
|
||||
"test": ["mock>=2.0"],
|
||||
|
|
|
@ -309,22 +309,16 @@ class RegisterRestServlet(RestServlet):
|
|||
assigned_user_id=registered_user_id,
|
||||
)
|
||||
|
||||
# Only give msisdn flows if the x_show_msisdn flag is given:
|
||||
# this is a hack to work around the fact that clients were shipped
|
||||
# that use fallback registration if they see any flows that they don't
|
||||
# recognise, which means we break registration for these clients if we
|
||||
# advertise msisdn flows. Once usage of Riot iOS <=0.3.9 and Riot
|
||||
# Android <=0.6.9 have fallen below an acceptable threshold, this
|
||||
# parameter should go away and we should always advertise msisdn flows.
|
||||
show_msisdn = False
|
||||
if 'x_show_msisdn' in body and body['x_show_msisdn']:
|
||||
show_msisdn = True
|
||||
|
||||
# FIXME: need a better error than "no auth flow found" for scenarios
|
||||
# where we required 3PID for registration but the user didn't give one
|
||||
require_email = 'email' in self.hs.config.registrations_require_3pid
|
||||
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
|
||||
|
||||
show_msisdn = True
|
||||
if self.hs.config.disable_msisdn_registration:
|
||||
show_msisdn = False
|
||||
require_msisdn = False
|
||||
|
||||
flows = []
|
||||
if self.hs.config.enable_registration_captcha:
|
||||
# only support 3PIDless registration if no 3PIDs are required
|
||||
|
|
|
@ -46,6 +46,7 @@ from synapse.federation.transport.client import TransportLayerClient
|
|||
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
|
||||
from synapse.groups.groups_server import GroupsServerHandler
|
||||
from synapse.handlers import Handlers
|
||||
from synapse.handlers.acme import AcmeHandler
|
||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
|
||||
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
||||
|
@ -129,6 +130,7 @@ class HomeServer(object):
|
|||
'sync_handler',
|
||||
'typing_handler',
|
||||
'room_list_handler',
|
||||
'acme_handler',
|
||||
'auth_handler',
|
||||
'device_handler',
|
||||
'e2e_keys_handler',
|
||||
|
@ -310,6 +312,9 @@ class HomeServer(object):
|
|||
def build_e2e_room_keys_handler(self):
|
||||
return E2eRoomKeysHandler(self)
|
||||
|
||||
def build_acme_handler(self):
|
||||
return AcmeHandler(self)
|
||||
|
||||
def build_application_service_api(self):
|
||||
return ApplicationServiceApi(self)
|
||||
|
||||
|
|
|
@ -192,6 +192,41 @@ class SQLBaseStore(object):
|
|||
|
||||
self.database_engine = hs.database_engine
|
||||
|
||||
# A set of tables that are not safe to use native upserts in.
|
||||
self._unsafe_to_upsert_tables = {"user_ips"}
|
||||
|
||||
if self.database_engine.can_native_upsert:
|
||||
# Check ASAP (and then later, every 1s) to see if we have finished
|
||||
# background updates of tables that aren't safe to update.
|
||||
self._clock.call_later(0.0, self._check_safe_to_upsert)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_safe_to_upsert(self):
|
||||
"""
|
||||
Is it safe to use native UPSERT?
|
||||
|
||||
If there are background updates, we will need to wait, as they may be
|
||||
the addition of indexes that set the UNIQUE constraint that we require.
|
||||
|
||||
If the background updates have not completed, wait a second and check again.
|
||||
"""
|
||||
updates = yield self._simple_select_list(
|
||||
"background_updates",
|
||||
keyvalues=None,
|
||||
retcols=["update_name"],
|
||||
desc="check_background_updates",
|
||||
)
|
||||
updates = [x["update_name"] for x in updates]
|
||||
|
||||
# The User IPs table in schema #53 was missing a unique index, which we
|
||||
# run as a background update.
|
||||
if "user_ips_device_unique_index" not in updates:
|
||||
self._unsafe_to_upsert_tables.discard("user_id")
|
||||
|
||||
# If there's any tables left to check, reschedule to run.
|
||||
if self._unsafe_to_upsert_tables:
|
||||
self._clock.call_later(1.0, self._check_safe_to_upsert)
|
||||
|
||||
def start_profiling(self):
|
||||
self._previous_loop_ts = self._clock.time_msec()
|
||||
|
||||
|
@ -494,8 +529,15 @@ class SQLBaseStore(object):
|
|||
txn.executemany(sql, vals)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _simple_upsert(self, table, keyvalues, values,
|
||||
insertion_values={}, desc="_simple_upsert", lock=True):
|
||||
def _simple_upsert(
|
||||
self,
|
||||
table,
|
||||
keyvalues,
|
||||
values,
|
||||
insertion_values={},
|
||||
desc="_simple_upsert",
|
||||
lock=True
|
||||
):
|
||||
"""
|
||||
|
||||
`lock` should generally be set to True (the default), but can be set
|
||||
|
@ -516,16 +558,21 @@ class SQLBaseStore(object):
|
|||
inserting
|
||||
lock (bool): True to lock the table when doing the upsert.
|
||||
Returns:
|
||||
Deferred(bool): True if a new entry was created, False if an
|
||||
existing one was updated.
|
||||
Deferred(None or bool): Native upserts always return None. Emulated
|
||||
upserts return True if a new entry was created, False if an existing
|
||||
one was updated.
|
||||
"""
|
||||
attempts = 0
|
||||
while True:
|
||||
try:
|
||||
result = yield self.runInteraction(
|
||||
desc,
|
||||
self._simple_upsert_txn, table, keyvalues, values, insertion_values,
|
||||
lock=lock
|
||||
self._simple_upsert_txn,
|
||||
table,
|
||||
keyvalues,
|
||||
values,
|
||||
insertion_values,
|
||||
lock=lock,
|
||||
)
|
||||
defer.returnValue(result)
|
||||
except self.database_engine.module.IntegrityError as e:
|
||||
|
@ -537,21 +584,76 @@ class SQLBaseStore(object):
|
|||
|
||||
# presumably we raced with another transaction: let's retry.
|
||||
logger.warn(
|
||||
"IntegrityError when upserting into %s; retrying: %s",
|
||||
table, e
|
||||
"%s when upserting into %s; retrying: %s", e.__name__, table, e
|
||||
)
|
||||
|
||||
def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
|
||||
lock=True):
|
||||
def _simple_upsert_txn(
|
||||
self,
|
||||
txn,
|
||||
table,
|
||||
keyvalues,
|
||||
values,
|
||||
insertion_values={},
|
||||
lock=True,
|
||||
):
|
||||
"""
|
||||
Pick the UPSERT method which works best on the platform. Either the
|
||||
native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
|
||||
|
||||
Args:
|
||||
txn: The transaction to use.
|
||||
table (str): The table to upsert into
|
||||
keyvalues (dict): The unique key tables and their new values
|
||||
values (dict): The nonunique columns and their new values
|
||||
insertion_values (dict): additional key/values to use only when
|
||||
inserting
|
||||
lock (bool): True to lock the table when doing the upsert.
|
||||
Returns:
|
||||
Deferred(None or bool): Native upserts always return None. Emulated
|
||||
upserts return True if a new entry was created, False if an existing
|
||||
one was updated.
|
||||
"""
|
||||
if (
|
||||
self.database_engine.can_native_upsert
|
||||
and table not in self._unsafe_to_upsert_tables
|
||||
):
|
||||
return self._simple_upsert_txn_native_upsert(
|
||||
txn,
|
||||
table,
|
||||
keyvalues,
|
||||
values,
|
||||
insertion_values=insertion_values,
|
||||
)
|
||||
else:
|
||||
return self._simple_upsert_txn_emulated(
|
||||
txn,
|
||||
table,
|
||||
keyvalues,
|
||||
values,
|
||||
insertion_values=insertion_values,
|
||||
lock=lock,
|
||||
)
|
||||
|
||||
def _simple_upsert_txn_emulated(
|
||||
self, txn, table, keyvalues, values, insertion_values={}, lock=True
|
||||
):
|
||||
# We need to lock the table :(, unless we're *really* careful
|
||||
if lock:
|
||||
self.database_engine.lock_table(txn, table)
|
||||
|
||||
def _getwhere(key):
|
||||
# If the value we're passing in is None (aka NULL), we need to use
|
||||
# IS, not =, as NULL = NULL equals NULL (False).
|
||||
if keyvalues[key] is None:
|
||||
return "%s IS ?" % (key,)
|
||||
else:
|
||||
return "%s = ?" % (key,)
|
||||
|
||||
# First try to update.
|
||||
sql = "UPDATE %s SET %s WHERE %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in values),
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
||||
" AND ".join(_getwhere(k) for k in keyvalues)
|
||||
)
|
||||
sqlargs = list(values.values()) + list(keyvalues.values())
|
||||
|
||||
|
@ -569,12 +671,44 @@ class SQLBaseStore(object):
|
|||
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
|
||||
table,
|
||||
", ".join(k for k in allvalues),
|
||||
", ".join("?" for _ in allvalues)
|
||||
", ".join("?" for _ in allvalues),
|
||||
)
|
||||
txn.execute(sql, list(allvalues.values()))
|
||||
# successfully inserted
|
||||
return True
|
||||
|
||||
def _simple_upsert_txn_native_upsert(
|
||||
self, txn, table, keyvalues, values, insertion_values={}
|
||||
):
|
||||
"""
|
||||
Use the native UPSERT functionality in recent PostgreSQL versions.
|
||||
|
||||
Args:
|
||||
table (str): The table to upsert into
|
||||
keyvalues (dict): The unique key tables and their new values
|
||||
values (dict): The nonunique columns and their new values
|
||||
insertion_values (dict): additional key/values to use only when
|
||||
inserting
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
allvalues = {}
|
||||
allvalues.update(keyvalues)
|
||||
allvalues.update(values)
|
||||
allvalues.update(insertion_values)
|
||||
|
||||
sql = (
|
||||
"INSERT INTO %s (%s) VALUES (%s) "
|
||||
"ON CONFLICT (%s) DO UPDATE SET %s"
|
||||
) % (
|
||||
table,
|
||||
", ".join(k for k in allvalues),
|
||||
", ".join("?" for _ in allvalues),
|
||||
", ".join(k for k in keyvalues),
|
||||
", ".join(k + "=EXCLUDED." + k for k in values),
|
||||
)
|
||||
txn.execute(sql, list(allvalues.values()))
|
||||
|
||||
def _simple_select_one(self, table, keyvalues, retcols,
|
||||
allow_none=False, desc="_simple_select_one"):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
|
|
|
@ -65,7 +65,27 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||
columns=["last_seen"],
|
||||
)
|
||||
|
||||
# (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
|
||||
self.register_background_update_handler(
|
||||
"user_ips_remove_dupes",
|
||||
self._remove_user_ip_dupes,
|
||||
)
|
||||
|
||||
# Register a unique index
|
||||
self.register_background_index_update(
|
||||
"user_ips_device_unique_index",
|
||||
index_name="user_ips_user_token_ip_unique_index",
|
||||
table="user_ips",
|
||||
columns=["user_id", "access_token", "ip"],
|
||||
unique=True,
|
||||
)
|
||||
|
||||
# Drop the old non-unique index
|
||||
self.register_background_update_handler(
|
||||
"user_ips_drop_nonunique_index",
|
||||
self._remove_user_ip_nonunique,
|
||||
)
|
||||
|
||||
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
|
||||
self._batch_row_update = {}
|
||||
|
||||
self._client_ip_looper = self._clock.looping_call(
|
||||
|
@ -75,6 +95,129 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||
"before", "shutdown", self._update_client_ips_batch
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _remove_user_ip_nonunique(self, progress, batch_size):
|
||||
def f(conn):
|
||||
txn = conn.cursor()
|
||||
txn.execute(
|
||||
"DROP INDEX IF EXISTS user_ips_user_ip"
|
||||
)
|
||||
txn.close()
|
||||
|
||||
yield self.runWithConnection(f)
|
||||
yield self._end_background_update("user_ips_drop_nonunique_index")
|
||||
defer.returnValue(1)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _remove_user_ip_dupes(self, progress, batch_size):
|
||||
# This works function works by scanning the user_ips table in batches
|
||||
# based on `last_seen`. For each row in a batch it searches the rest of
|
||||
# the table to see if there are any duplicates, if there are then they
|
||||
# are removed and replaced with a suitable row.
|
||||
|
||||
# Fetch the start of the batch
|
||||
begin_last_seen = progress.get("last_seen", 0)
|
||||
|
||||
def get_last_seen(txn):
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT last_seen FROM user_ips
|
||||
WHERE last_seen > ?
|
||||
ORDER BY last_seen
|
||||
LIMIT 1
|
||||
OFFSET ?
|
||||
""",
|
||||
(begin_last_seen, batch_size)
|
||||
)
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
return row[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
# Get a last seen that has roughly `batch_size` since `begin_last_seen`
|
||||
end_last_seen = yield self.runInteraction(
|
||||
"user_ips_dups_get_last_seen", get_last_seen
|
||||
)
|
||||
|
||||
# If it returns None, then we're processing the last batch
|
||||
last = end_last_seen is None
|
||||
|
||||
logger.info(
|
||||
"Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s",
|
||||
begin_last_seen, end_last_seen,
|
||||
)
|
||||
|
||||
def remove(txn):
|
||||
# This works by looking at all entries in the given time span, and
|
||||
# then for each (user_id, access_token, ip) tuple in that range
|
||||
# checking for any duplicates in the rest of the table (via a join).
|
||||
# It then only returns entries which have duplicates, and the max
|
||||
# last_seen across all duplicates, which can the be used to delete
|
||||
# all other duplicates.
|
||||
# It is efficient due to the existence of (user_id, access_token,
|
||||
# ip) and (last_seen) indices.
|
||||
|
||||
# Define the search space, which requires handling the last batch in
|
||||
# a different way
|
||||
if last:
|
||||
clause = "? <= last_seen"
|
||||
args = (begin_last_seen,)
|
||||
else:
|
||||
clause = "? <= last_seen AND last_seen < ?"
|
||||
args = (begin_last_seen, end_last_seen)
|
||||
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT user_id, access_token, ip,
|
||||
MAX(device_id), MAX(user_agent), MAX(last_seen)
|
||||
FROM (
|
||||
SELECT user_id, access_token, ip
|
||||
FROM user_ips
|
||||
WHERE {}
|
||||
) c
|
||||
INNER JOIN user_ips USING (user_id, access_token, ip)
|
||||
GROUP BY user_id, access_token, ip
|
||||
HAVING count(*) > 1
|
||||
""".format(clause),
|
||||
args
|
||||
)
|
||||
res = txn.fetchall()
|
||||
|
||||
# We've got some duplicates
|
||||
for i in res:
|
||||
user_id, access_token, ip, device_id, user_agent, last_seen = i
|
||||
|
||||
# Drop all the duplicates
|
||||
txn.execute(
|
||||
"""
|
||||
DELETE FROM user_ips
|
||||
WHERE user_id = ? AND access_token = ? AND ip = ?
|
||||
""",
|
||||
(user_id, access_token, ip)
|
||||
)
|
||||
|
||||
# Add in one to be the last_seen
|
||||
txn.execute(
|
||||
"""
|
||||
INSERT INTO user_ips
|
||||
(user_id, access_token, ip, device_id, user_agent, last_seen)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, access_token, ip, device_id, user_agent, last_seen)
|
||||
)
|
||||
|
||||
self._background_update_progress_txn(
|
||||
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
|
||||
)
|
||||
|
||||
yield self.runInteraction("user_ips_dups_remove", remove)
|
||||
|
||||
if last:
|
||||
yield self._end_background_update("user_ips_remove_dupes")
|
||||
|
||||
defer.returnValue(batch_size)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
|
||||
now=None):
|
||||
|
@ -114,7 +257,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||
)
|
||||
|
||||
def _update_client_ips_batch_txn(self, txn, to_update):
|
||||
self.database_engine.lock_table(txn, "user_ips")
|
||||
if "user_ips" in self._unsafe_to_upsert_tables or (
|
||||
not self.database_engine.can_native_upsert
|
||||
):
|
||||
self.database_engine.lock_table(txn, "user_ips")
|
||||
|
||||
for entry in iteritems(to_update):
|
||||
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
|
||||
|
@ -127,10 +273,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||
"user_id": user_id,
|
||||
"access_token": access_token,
|
||||
"ip": ip,
|
||||
"user_agent": user_agent,
|
||||
"device_id": device_id,
|
||||
},
|
||||
values={
|
||||
"user_agent": user_agent,
|
||||
"device_id": device_id,
|
||||
"last_seen": last_seen,
|
||||
},
|
||||
lock=False,
|
||||
|
@ -227,7 +373,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||
results = {}
|
||||
|
||||
for key in self._batch_row_update:
|
||||
uid, access_token, ip = key
|
||||
uid, access_token, ip, = key
|
||||
if uid == user_id:
|
||||
user_agent, _, last_seen = self._batch_row_update[key]
|
||||
results[(access_token, ip)] = (user_agent, last_seen)
|
||||
|
|
|
@ -18,7 +18,7 @@ import platform
|
|||
|
||||
from ._base import IncorrectDatabaseSetup
|
||||
from .postgres import PostgresEngine
|
||||
from .sqlite3 import Sqlite3Engine
|
||||
from .sqlite import Sqlite3Engine
|
||||
|
||||
SUPPORTED_MODULE = {
|
||||
"sqlite3": Sqlite3Engine,
|
||||
|
|
|
@ -38,6 +38,13 @@ class PostgresEngine(object):
|
|||
return sql.replace("?", "%s")
|
||||
|
||||
def on_new_connection(self, db_conn):
|
||||
|
||||
# Get the version of PostgreSQL that we're using. As per the psycopg2
|
||||
# docs: The number is formed by converting the major, minor, and
|
||||
# revision numbers into two-decimal-digit numbers and appending them
|
||||
# together. For example, version 8.1.5 will be returned as 80105
|
||||
self._version = db_conn.server_version
|
||||
|
||||
db_conn.set_isolation_level(
|
||||
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
|
||||
)
|
||||
|
@ -54,6 +61,13 @@ class PostgresEngine(object):
|
|||
|
||||
cursor.close()
|
||||
|
||||
@property
|
||||
def can_native_upsert(self):
|
||||
"""
|
||||
Can we use native UPSERTs? This requires PostgreSQL 9.5+.
|
||||
"""
|
||||
return self._version >= 90500
|
||||
|
||||
def is_deadlock(self, error):
|
||||
if isinstance(error, self.module.DatabaseError):
|
||||
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
import struct
|
||||
import threading
|
||||
from sqlite3 import sqlite_version_info
|
||||
|
||||
from synapse.storage.prepare_database import prepare_database
|
||||
|
||||
|
@ -30,6 +31,14 @@ class Sqlite3Engine(object):
|
|||
self._current_state_group_id = None
|
||||
self._current_state_group_id_lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def can_native_upsert(self):
|
||||
"""
|
||||
Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
|
||||
more work we haven't done yet to tell what was inserted vs updated.
|
||||
"""
|
||||
return sqlite_version_info >= (3, 24, 0)
|
||||
|
||||
def check_database(self, txn):
|
||||
pass
|
||||
|
|
@ -739,7 +739,18 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
|
|||
}
|
||||
|
||||
events_map = {ev.event_id: ev for ev, _ in events_context}
|
||||
room_version = yield self.get_room_version(room_id)
|
||||
|
||||
# We need to get the room version, which is in the create event.
|
||||
# Normally that'd be in the database, but its also possible that we're
|
||||
# currently trying to persist it.
|
||||
room_version = None
|
||||
for ev, _ in events_context:
|
||||
if ev.type == EventTypes.Create and ev.state_key == "":
|
||||
room_version = ev.content.get("room_version", "1")
|
||||
break
|
||||
|
||||
if not room_version:
|
||||
room_version = yield self.get_room_version(room_id)
|
||||
|
||||
logger.debug("calling resolve_state_groups from preserve_events")
|
||||
res = yield self._state_resolution_handler.resolve_state_groups(
|
||||
|
|
|
@ -215,7 +215,7 @@ class PusherStore(PusherWorkerStore):
|
|||
with self._pushers_id_gen.get_next() as stream_id:
|
||||
# no need to lock because `pushers` has a unique key on
|
||||
# (app_id, pushkey, user_name) so _simple_upsert will retry
|
||||
newly_inserted = yield self._simple_upsert(
|
||||
yield self._simple_upsert(
|
||||
table="pushers",
|
||||
keyvalues={
|
||||
"app_id": app_id,
|
||||
|
@ -238,7 +238,12 @@ class PusherStore(PusherWorkerStore):
|
|||
lock=False,
|
||||
)
|
||||
|
||||
if newly_inserted:
|
||||
user_has_pusher = self.get_if_user_has_pusher.cache.get(
|
||||
(user_id,), None, update_metrics=False
|
||||
)
|
||||
|
||||
if user_has_pusher is not True:
|
||||
# invalidate, since we the user might not have had a pusher before
|
||||
yield self.runInteraction(
|
||||
"add_pusher",
|
||||
self._invalidate_cache_and_stream,
|
||||
|
|
26
synapse/storage/schema/delta/53/user_ips_index.sql
Normal file
26
synapse/storage/schema/delta/53/user_ips_index.sql
Normal file
|
@ -0,0 +1,26 @@
|
|||
/* Copyright 2018 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- delete duplicates
|
||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||
('user_ips_remove_dupes', '{}');
|
||||
|
||||
-- add a new unique index to user_ips table
|
||||
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
|
||||
('user_ips_device_unique_index', '{}', 'user_ips_remove_dupes');
|
||||
|
||||
-- drop the old original index
|
||||
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
|
||||
('user_ips_drop_nonunique_index', '{}', 'user_ips_device_unique_index');
|
|
@ -168,14 +168,14 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# We weight the localpart most highly, then display name and finally
|
||||
# server name
|
||||
if new_entry:
|
||||
if self.database_engine.can_native_upsert:
|
||||
sql = """
|
||||
INSERT INTO user_directory_search(user_id, vector)
|
||||
VALUES (?,
|
||||
setweight(to_tsvector('english', ?), 'A')
|
||||
|| setweight(to_tsvector('english', ?), 'D')
|
||||
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
|
||||
)
|
||||
) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
|
@ -185,20 +185,45 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
)
|
||||
)
|
||||
else:
|
||||
sql = """
|
||||
UPDATE user_directory_search
|
||||
SET vector = setweight(to_tsvector('english', ?), 'A')
|
||||
|| setweight(to_tsvector('english', ?), 'D')
|
||||
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
|
||||
WHERE user_id = ?
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
get_localpart_from_id(user_id), get_domain_from_id(user_id),
|
||||
display_name, user_id,
|
||||
# TODO: Remove this code after we've bumped the minimum version
|
||||
# of postgres to always support upserts, so we can get rid of
|
||||
# `new_entry` usage
|
||||
if new_entry is True:
|
||||
sql = """
|
||||
INSERT INTO user_directory_search(user_id, vector)
|
||||
VALUES (?,
|
||||
setweight(to_tsvector('english', ?), 'A')
|
||||
|| setweight(to_tsvector('english', ?), 'D')
|
||||
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
|
||||
)
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
user_id, get_localpart_from_id(user_id),
|
||||
get_domain_from_id(user_id), display_name,
|
||||
)
|
||||
)
|
||||
elif new_entry is False:
|
||||
sql = """
|
||||
UPDATE user_directory_search
|
||||
SET vector = setweight(to_tsvector('english', ?), 'A')
|
||||
|| setweight(to_tsvector('english', ?), 'D')
|
||||
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
|
||||
WHERE user_id = ?
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
get_localpart_from_id(user_id),
|
||||
get_domain_from_id(user_id),
|
||||
display_name, user_id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"upsert returned None when 'can_native_upsert' is False"
|
||||
)
|
||||
)
|
||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||
value = "%s %s" % (user_id, display_name,) if display_name else user_id
|
||||
self._simple_upsert_txn(
|
||||
|
|
|
@ -387,12 +387,14 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
|
|||
deferred that wraps and times out the given deferred, correctly handling
|
||||
the case where the given deferred's canceller throws.
|
||||
|
||||
(See https://twistedmatrix.com/trac/ticket/9534)
|
||||
|
||||
NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred
|
||||
|
||||
Args:
|
||||
deferred (Deferred)
|
||||
timeout (float): Timeout in seconds
|
||||
reactor (twisted.internet.reactor): The twisted reactor to use
|
||||
reactor (twisted.interfaces.IReactorTime): The twisted reactor to use
|
||||
on_timeout_cancel (callable): A callable which is called immediately
|
||||
after the deferred times out, and not if this deferred is
|
||||
otherwise cancelled before the timeout.
|
||||
|
|
|
@ -51,7 +51,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
|
|||
"lemurs.win.log.config",
|
||||
"lemurs.win.signing.key",
|
||||
"lemurs.win.tls.crt",
|
||||
"lemurs.win.tls.dh",
|
||||
"lemurs.win.tls.key",
|
||||
]
|
||||
),
|
||||
|
|
14
tests/http/federation/__init__.py
Normal file
14
tests/http/federation/__init__.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
240
tests/http/federation/test_matrix_federation_agent.py
Normal file
240
tests/http/federation/test_matrix_federation_agent.py
Normal file
|
@ -0,0 +1,240 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from mock import Mock
|
||||
|
||||
import treq
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||
from twisted.test.ssl_helpers import ServerTLSContext
|
||||
from twisted.web.http import HTTPChannel
|
||||
|
||||
from synapse.crypto.context_factory import ClientTLSOptionsFactory
|
||||
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
||||
from tests.unittest import TestCase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MatrixFederationAgentTests(TestCase):
|
||||
def setUp(self):
|
||||
self.reactor = ThreadedMemoryReactorClock()
|
||||
|
||||
self.mock_resolver = Mock()
|
||||
|
||||
self.agent = MatrixFederationAgent(
|
||||
reactor=self.reactor,
|
||||
tls_client_options_factory=ClientTLSOptionsFactory(None),
|
||||
_srv_resolver=self.mock_resolver,
|
||||
)
|
||||
|
||||
def _make_connection(self, client_factory, expected_sni):
|
||||
"""Builds a test server, and completes the outgoing client connection
|
||||
|
||||
Returns:
|
||||
HTTPChannel: the test server
|
||||
"""
|
||||
|
||||
# build the test server
|
||||
server_tls_protocol = _build_test_server()
|
||||
|
||||
# now, tell the client protocol factory to build the client protocol (it will be a
|
||||
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
|
||||
# HTTP11ClientProtocol) and wire the output of said protocol up to the server via
|
||||
# a FakeTransport.
|
||||
#
|
||||
# Normally this would be done by the TCP socket code in Twisted, but we are
|
||||
# stubbing that out here.
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
client_protocol.makeConnection(FakeTransport(server_tls_protocol, self.reactor))
|
||||
|
||||
# tell the server tls protocol to send its stuff back to the client, too
|
||||
server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
|
||||
|
||||
# give the reactor a pump to get the TLS juices flowing.
|
||||
self.reactor.pump((0.1,))
|
||||
|
||||
# check the SNI
|
||||
server_name = server_tls_protocol._tlsConnection.get_servername()
|
||||
self.assertEqual(
|
||||
server_name,
|
||||
expected_sni,
|
||||
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
||||
)
|
||||
|
||||
# fish the test server back out of the server-side TLS protocol.
|
||||
return server_tls_protocol.wrappedProtocol
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _make_get_request(self, uri):
|
||||
"""
|
||||
Sends a simple GET request via the agent, and checks its logcontext management
|
||||
"""
|
||||
with LoggingContext("one") as context:
|
||||
fetch_d = self.agent.request(b'GET', uri)
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertNoResult(fetch_d)
|
||||
|
||||
# should have reset logcontext to the sentinel
|
||||
_check_logcontext(LoggingContext.sentinel)
|
||||
|
||||
try:
|
||||
fetch_res = yield fetch_d
|
||||
defer.returnValue(fetch_res)
|
||||
finally:
|
||||
_check_logcontext(context)
|
||||
|
||||
def test_get(self):
|
||||
"""
|
||||
happy-path test of a GET request
|
||||
"""
|
||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||
test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertNoResult(test_d)
|
||||
|
||||
# Make sure treq is trying to connect
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 8448)
|
||||
|
||||
# make a test server, and wire up the client
|
||||
http_server = self._make_connection(
|
||||
client_factory,
|
||||
expected_sni=b"testserv",
|
||||
)
|
||||
|
||||
self.assertEqual(len(http_server.requests), 1)
|
||||
request = http_server.requests[0]
|
||||
self.assertEqual(request.method, b'GET')
|
||||
self.assertEqual(request.path, b'/foo/bar')
|
||||
self.assertEqual(
|
||||
request.requestHeaders.getRawHeaders(b'host'),
|
||||
[b'testserv:8448']
|
||||
)
|
||||
content = request.content.read()
|
||||
self.assertEqual(content, b'')
|
||||
|
||||
# Deferred is still without a result
|
||||
self.assertNoResult(test_d)
|
||||
|
||||
# send the headers
|
||||
request.responseHeaders.setRawHeaders(b'Content-Type', [b'application/json'])
|
||||
request.write('')
|
||||
|
||||
self.reactor.pump((0.1,))
|
||||
|
||||
response = self.successResultOf(test_d)
|
||||
|
||||
# that should give us a Response object
|
||||
self.assertEqual(response.code, 200)
|
||||
|
||||
# Send the body
|
||||
request.write('{ "a": 1 }'.encode('ascii'))
|
||||
request.finish()
|
||||
|
||||
self.reactor.pump((0.1,))
|
||||
|
||||
# check it can be read
|
||||
json = self.successResultOf(treq.json_content(response))
|
||||
self.assertEqual(json, {"a": 1})
|
||||
|
||||
def test_get_ip_address(self):
|
||||
"""
|
||||
Test the behaviour when the server name contains an explicit IP (with no port)
|
||||
"""
|
||||
|
||||
# the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?)
|
||||
self.mock_resolver.resolve_service.side_effect = lambda _: []
|
||||
|
||||
# then there will be a getaddrinfo on the IP
|
||||
self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
|
||||
|
||||
test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertNoResult(test_d)
|
||||
|
||||
self.mock_resolver.resolve_service.assert_called_once()
|
||||
|
||||
# Make sure treq is trying to connect
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 8448)
|
||||
|
||||
# make a test server, and wire up the client
|
||||
http_server = self._make_connection(
|
||||
client_factory,
|
||||
expected_sni=None,
|
||||
)
|
||||
|
||||
self.assertEqual(len(http_server.requests), 1)
|
||||
request = http_server.requests[0]
|
||||
self.assertEqual(request.method, b'GET')
|
||||
self.assertEqual(request.path, b'/foo/bar')
|
||||
# XXX currently broken
|
||||
# self.assertEqual(
|
||||
# request.requestHeaders.getRawHeaders(b'host'),
|
||||
# [b'1.2.3.4:8448']
|
||||
# )
|
||||
|
||||
# finish the request
|
||||
request.finish()
|
||||
self.reactor.pump((0.1,))
|
||||
self.successResultOf(test_d)
|
||||
|
||||
|
||||
def _check_logcontext(context):
|
||||
current = LoggingContext.current_context()
|
||||
if current is not context:
|
||||
raise AssertionError(
|
||||
"Expected logcontext %s but was %s" % (context, current),
|
||||
)
|
||||
|
||||
|
||||
def _build_test_server():
|
||||
"""Construct a test server
|
||||
|
||||
This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
|
||||
|
||||
Returns:
|
||||
TLSMemoryBIOProtocol
|
||||
"""
|
||||
server_factory = Factory.forProtocol(HTTPChannel)
|
||||
# Request.finish expects the factory to have a 'log' method.
|
||||
server_factory.log = _log_request
|
||||
|
||||
server_tls_factory = TLSMemoryBIOFactory(
|
||||
ServerTLSContext(), isClient=False, wrappedFactory=server_factory,
|
||||
)
|
||||
|
||||
return server_tls_factory.buildProtocol(None)
|
||||
|
||||
|
||||
def _log_request(request):
|
||||
"""Implements Factory.log, which is expected by Request.finish"""
|
||||
logger.info("Completed request %s", request)
|
207
tests/http/federation/test_srv_resolver.py
Normal file
207
tests/http/federation/test_srv_resolver.py
Normal file
|
@ -0,0 +1,207 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.error import ConnectError
|
||||
from twisted.names import dns, error
|
||||
|
||||
from synapse.http.federation.srv_resolver import SrvResolver
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import MockClock
|
||||
|
||||
|
||||
class SrvResolverTestCase(unittest.TestCase):
|
||||
def test_resolve(self):
|
||||
dns_client_mock = Mock()
|
||||
|
||||
service_name = b"test_service.example.com"
|
||||
host_name = b"example.com"
|
||||
|
||||
answer_srv = dns.RRHeader(
|
||||
type=dns.SRV, payload=dns.Record_SRV(target=host_name)
|
||||
)
|
||||
|
||||
result_deferred = Deferred()
|
||||
dns_client_mock.lookupService.return_value = result_deferred
|
||||
|
||||
cache = {}
|
||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_lookup():
|
||||
|
||||
with LoggingContext("one") as ctx:
|
||||
resolve_d = resolver.resolve_service(service_name)
|
||||
|
||||
self.assertNoResult(resolve_d)
|
||||
|
||||
# should have reset to the sentinel context
|
||||
self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
|
||||
|
||||
result = yield resolve_d
|
||||
|
||||
# should have restored our context
|
||||
self.assertIs(LoggingContext.current_context(), ctx)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
test_d = do_lookup()
|
||||
self.assertNoResult(test_d)
|
||||
|
||||
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
||||
|
||||
result_deferred.callback(
|
||||
([answer_srv], None, None)
|
||||
)
|
||||
|
||||
servers = self.successResultOf(test_d)
|
||||
|
||||
self.assertEquals(len(servers), 1)
|
||||
self.assertEquals(servers, cache[service_name])
|
||||
self.assertEquals(servers[0].host, host_name)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_from_cache_expired_and_dns_fail(self):
|
||||
dns_client_mock = Mock()
|
||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
||||
|
||||
service_name = b"test_service.example.com"
|
||||
|
||||
entry = Mock(spec_set=["expires"])
|
||||
entry.expires = 0
|
||||
|
||||
cache = {service_name: [entry]}
|
||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||
|
||||
servers = yield resolver.resolve_service(service_name)
|
||||
|
||||
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
||||
|
||||
self.assertEquals(len(servers), 1)
|
||||
self.assertEquals(servers, cache[service_name])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_from_cache(self):
|
||||
clock = MockClock()
|
||||
|
||||
dns_client_mock = Mock(spec_set=['lookupService'])
|
||||
dns_client_mock.lookupService = Mock(spec_set=[])
|
||||
|
||||
service_name = b"test_service.example.com"
|
||||
|
||||
entry = Mock(spec_set=["expires"])
|
||||
entry.expires = 999999999
|
||||
|
||||
cache = {service_name: [entry]}
|
||||
resolver = SrvResolver(
|
||||
dns_client=dns_client_mock, cache=cache, get_time=clock.time,
|
||||
)
|
||||
|
||||
servers = yield resolver.resolve_service(service_name)
|
||||
|
||||
self.assertFalse(dns_client_mock.lookupService.called)
|
||||
|
||||
self.assertEquals(len(servers), 1)
|
||||
self.assertEquals(servers, cache[service_name])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_empty_cache(self):
|
||||
dns_client_mock = Mock()
|
||||
|
||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
||||
|
||||
service_name = b"test_service.example.com"
|
||||
|
||||
cache = {}
|
||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||
|
||||
with self.assertRaises(error.DNSServerError):
|
||||
yield resolver.resolve_service(service_name)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_name_error(self):
|
||||
dns_client_mock = Mock()
|
||||
|
||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
|
||||
|
||||
service_name = b"test_service.example.com"
|
||||
|
||||
cache = {}
|
||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||
|
||||
servers = yield resolver.resolve_service(service_name)
|
||||
|
||||
self.assertEquals(len(servers), 0)
|
||||
self.assertEquals(len(cache), 0)
|
||||
|
||||
def test_disabled_service(self):
|
||||
"""
|
||||
test the behaviour when there is a single record which is ".".
|
||||
"""
|
||||
service_name = b"test_service.example.com"
|
||||
|
||||
lookup_deferred = Deferred()
|
||||
dns_client_mock = Mock()
|
||||
dns_client_mock.lookupService.return_value = lookup_deferred
|
||||
cache = {}
|
||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||
|
||||
resolve_d = resolver.resolve_service(service_name)
|
||||
self.assertNoResult(resolve_d)
|
||||
|
||||
# returning a single "." should make the lookup fail with a ConenctError
|
||||
lookup_deferred.callback((
|
||||
[dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))],
|
||||
None,
|
||||
None,
|
||||
))
|
||||
|
||||
self.failureResultOf(resolve_d, ConnectError)
|
||||
|
||||
def test_non_srv_answer(self):
|
||||
"""
|
||||
test the behaviour when the dns server gives us a spurious non-SRV response
|
||||
"""
|
||||
service_name = b"test_service.example.com"
|
||||
|
||||
lookup_deferred = Deferred()
|
||||
dns_client_mock = Mock()
|
||||
dns_client_mock.lookupService.return_value = lookup_deferred
|
||||
cache = {}
|
||||
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||
|
||||
resolve_d = resolver.resolve_service(service_name)
|
||||
self.assertNoResult(resolve_d)
|
||||
|
||||
lookup_deferred.callback((
|
||||
[
|
||||
dns.RRHeader(type=dns.A, payload=dns.Record_A()),
|
||||
dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")),
|
||||
],
|
||||
None,
|
||||
None,
|
||||
))
|
||||
|
||||
servers = self.successResultOf(resolve_d)
|
||||
|
||||
self.assertEquals(len(servers), 1)
|
||||
self.assertEquals(servers, cache[service_name])
|
||||
self.assertEquals(servers[0].host, b"host")
|
|
@ -15,8 +15,10 @@
|
|||
|
||||
from mock import Mock
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import TimeoutError
|
||||
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
from twisted.web.client import ResponseNeverReceived
|
||||
from twisted.web.http import HTTPChannel
|
||||
|
||||
|
@ -25,11 +27,20 @@ from synapse.http.matrixfederationclient import (
|
|||
MatrixFederationHttpClient,
|
||||
MatrixFederationRequest,
|
||||
)
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
from tests.server import FakeTransport
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
|
||||
def check_logcontext(context):
|
||||
current = LoggingContext.current_context()
|
||||
if current is not context:
|
||||
raise AssertionError(
|
||||
"Expected logcontext %s but was %s" % (context, current),
|
||||
)
|
||||
|
||||
|
||||
class FederationClientTests(HomeserverTestCase):
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
|
@ -42,9 +53,73 @@ class FederationClientTests(HomeserverTestCase):
|
|||
self.cl = MatrixFederationHttpClient(self.hs)
|
||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||
|
||||
def test_client_get(self):
|
||||
"""
|
||||
happy-path test of a GET request
|
||||
"""
|
||||
@defer.inlineCallbacks
|
||||
def do_request():
|
||||
with LoggingContext("one") as context:
|
||||
fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertNoResult(fetch_d)
|
||||
|
||||
# should have reset logcontext to the sentinel
|
||||
check_logcontext(LoggingContext.sentinel)
|
||||
|
||||
try:
|
||||
fetch_res = yield fetch_d
|
||||
defer.returnValue(fetch_res)
|
||||
finally:
|
||||
check_logcontext(context)
|
||||
|
||||
test_d = do_request()
|
||||
|
||||
self.pump()
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertNoResult(test_d)
|
||||
|
||||
# Make sure treq is trying to connect
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 8008)
|
||||
|
||||
# complete the connection and wire it up to a fake transport
|
||||
protocol = factory.buildProtocol(None)
|
||||
transport = StringTransport()
|
||||
protocol.makeConnection(transport)
|
||||
|
||||
# that should have made it send the request to the transport
|
||||
self.assertRegex(transport.value(), b"^GET /foo/bar")
|
||||
|
||||
# Deferred is still without a result
|
||||
self.assertNoResult(test_d)
|
||||
|
||||
# Send it the HTTP response
|
||||
res_json = '{ "a": 1 }'.encode('ascii')
|
||||
protocol.dataReceived(
|
||||
b"HTTP/1.1 200 OK\r\n"
|
||||
b"Server: Fake\r\n"
|
||||
b"Content-Type: application/json\r\n"
|
||||
b"Content-Length: %i\r\n"
|
||||
b"\r\n"
|
||||
b"%s" % (len(res_json), res_json)
|
||||
)
|
||||
|
||||
self.pump()
|
||||
|
||||
res = self.successResultOf(test_d)
|
||||
|
||||
# check the response is as expected
|
||||
self.assertEqual(res, {"a": 1})
|
||||
|
||||
def test_dns_error(self):
|
||||
"""
|
||||
If the DNS raising returns an error, it will bubble up.
|
||||
If the DNS lookup returns an error, it will bubble up.
|
||||
"""
|
||||
d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
|
||||
self.pump()
|
||||
|
@ -53,6 +128,28 @@ class FederationClientTests(HomeserverTestCase):
|
|||
self.assertIsInstance(f.value, RequestSendFailed)
|
||||
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
|
||||
|
||||
def test_client_connection_refused(self):
|
||||
d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
|
||||
|
||||
self.pump()
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertNoResult(d)
|
||||
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, '1.2.3.4')
|
||||
self.assertEqual(port, 8008)
|
||||
e = Exception("go away")
|
||||
factory.clientConnectionFailed(None, e)
|
||||
self.pump(0.5)
|
||||
|
||||
f = self.failureResultOf(d)
|
||||
|
||||
self.assertIsInstance(f.value, RequestSendFailed)
|
||||
self.assertIs(f.value.inner_exception, e)
|
||||
|
||||
def test_client_never_connect(self):
|
||||
"""
|
||||
If the HTTP request is not connected and is timed out, it'll give a
|
||||
|
@ -63,7 +160,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||
self.pump()
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertFalse(d.called)
|
||||
self.assertNoResult(d)
|
||||
|
||||
# Make sure treq is trying to connect
|
||||
clients = self.reactor.tcpClients
|
||||
|
@ -72,7 +169,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||
self.assertEqual(clients[0][1], 8008)
|
||||
|
||||
# Deferred is still without a result
|
||||
self.assertFalse(d.called)
|
||||
self.assertNoResult(d)
|
||||
|
||||
# Push by enough to time it out
|
||||
self.reactor.advance(10.5)
|
||||
|
@ -94,7 +191,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||
self.pump()
|
||||
|
||||
# Nothing happened yet
|
||||
self.assertFalse(d.called)
|
||||
self.assertNoResult(d)
|
||||
|
||||
# Make sure treq is trying to connect
|
||||
clients = self.reactor.tcpClients
|
||||
|
@ -107,7 +204,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||
client.makeConnection(conn)
|
||||
|
||||
# Deferred is still without a result
|
||||
self.assertFalse(d.called)
|
||||
self.assertNoResult(d)
|
||||
|
||||
# Push by enough to time it out
|
||||
self.reactor.advance(10.5)
|
||||
|
@ -135,7 +232,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||
client.makeConnection(conn)
|
||||
|
||||
# Deferred does not have a result
|
||||
self.assertFalse(d.called)
|
||||
self.assertNoResult(d)
|
||||
|
||||
# Send it the HTTP response
|
||||
client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")
|
||||
|
@ -159,7 +256,7 @@ class FederationClientTests(HomeserverTestCase):
|
|||
client.makeConnection(conn)
|
||||
|
||||
# Deferred does not have a result
|
||||
self.assertFalse(d.called)
|
||||
self.assertNoResult(d)
|
||||
|
||||
# Send it the HTTP response
|
||||
client.dataReceived(
|
||||
|
@ -195,3 +292,42 @@ class FederationClientTests(HomeserverTestCase):
|
|||
request = server.requests[0]
|
||||
content = request.content.read()
|
||||
self.assertEqual(content, b'{"a":"b"}')
|
||||
|
||||
def test_closes_connection(self):
|
||||
"""Check that the client closes unused HTTP connections"""
|
||||
d = self.cl.get_json("testserv:8008", "foo/bar")
|
||||
|
||||
self.pump()
|
||||
|
||||
# there should have been a call to connectTCP
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(_host, _port, factory, _timeout, _bindAddress) = clients[0]
|
||||
|
||||
# complete the connection and wire it up to a fake transport
|
||||
client = factory.buildProtocol(None)
|
||||
conn = StringTransport()
|
||||
client.makeConnection(conn)
|
||||
|
||||
# that should have made it send the request to the connection
|
||||
self.assertRegex(conn.value(), b"^GET /foo/bar")
|
||||
|
||||
# Send the HTTP response
|
||||
client.dataReceived(
|
||||
b"HTTP/1.1 200 OK\r\n"
|
||||
b"Content-Type: application/json\r\n"
|
||||
b"Content-Length: 2\r\n"
|
||||
b"\r\n"
|
||||
b"{}"
|
||||
)
|
||||
|
||||
# We should get a successful response
|
||||
r = self.successResultOf(d)
|
||||
self.assertEqual(r, {})
|
||||
|
||||
self.assertFalse(conn.disconnecting)
|
||||
|
||||
# wait for a while
|
||||
self.pump(120)
|
||||
|
||||
self.assertTrue(conn.disconnecting)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
from six import text_type
|
||||
|
@ -22,6 +23,8 @@ from synapse.util import Clock
|
|||
|
||||
from tests.utils import setup_test_homeserver as _sth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TimedOutException(Exception):
|
||||
"""
|
||||
|
@ -339,7 +342,7 @@ def get_clock():
|
|||
return (clock, hs_clock)
|
||||
|
||||
|
||||
@attr.s
|
||||
@attr.s(cmp=False)
|
||||
class FakeTransport(object):
|
||||
"""
|
||||
A twisted.internet.interfaces.ITransport implementation which sends all its data
|
||||
|
@ -414,6 +417,11 @@ class FakeTransport(object):
|
|||
self.buffer = self.buffer + byt
|
||||
|
||||
def _write():
|
||||
if not self.buffer:
|
||||
# nothing to do. Don't write empty buffers: it upsets the
|
||||
# TLSMemoryBIOProtocol
|
||||
return
|
||||
|
||||
if getattr(self.other, "transport") is not None:
|
||||
self.other.dataReceived(self.buffer)
|
||||
self.buffer = b""
|
||||
|
@ -421,7 +429,10 @@ class FakeTransport(object):
|
|||
|
||||
self._reactor.callLater(0.0, _write)
|
||||
|
||||
_write()
|
||||
# always actually do the write asynchronously. Some protocols (notably the
|
||||
# TLSMemoryBIOProtocol) get very confused if a read comes back while they are
|
||||
# still doing a write. Doing a callLater here breaks the cycle.
|
||||
self._reactor.callLater(0.0, _write)
|
||||
|
||||
def writeSequence(self, seq):
|
||||
for x in seq:
|
||||
|
|
|
@ -49,6 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
|||
self.db_pool.runWithConnection = runWithConnection
|
||||
|
||||
config = Mock()
|
||||
config._enable_native_upserts = False
|
||||
config.event_cache_size = 1
|
||||
config.database_config = {"name": "sqlite3"}
|
||||
hs = TestHomeServer(
|
||||
|
|
|
@ -62,6 +62,77 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
|
|||
r,
|
||||
)
|
||||
|
||||
def test_insert_new_client_ip_none_device_id(self):
|
||||
"""
|
||||
An insert with a device ID of NULL will not create a new entry, but
|
||||
update an existing entry in the user_ips table.
|
||||
"""
|
||||
self.reactor.advance(12345678)
|
||||
|
||||
user_id = "@user:id"
|
||||
|
||||
# Add & trigger the storage loop
|
||||
self.get_success(
|
||||
self.store.insert_client_ip(
|
||||
user_id, "access_token", "ip", "user_agent", None
|
||||
)
|
||||
)
|
||||
self.reactor.advance(200)
|
||||
self.pump(0)
|
||||
|
||||
result = self.get_success(
|
||||
self.store._simple_select_list(
|
||||
table="user_ips",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
|
||||
desc="get_user_ip_and_agents",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
result,
|
||||
[
|
||||
{
|
||||
'access_token': 'access_token',
|
||||
'ip': 'ip',
|
||||
'user_agent': 'user_agent',
|
||||
'device_id': None,
|
||||
'last_seen': 12345678000,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
# Add another & trigger the storage loop
|
||||
self.get_success(
|
||||
self.store.insert_client_ip(
|
||||
user_id, "access_token", "ip", "user_agent", None
|
||||
)
|
||||
)
|
||||
self.reactor.advance(10)
|
||||
self.pump(0)
|
||||
|
||||
result = self.get_success(
|
||||
self.store._simple_select_list(
|
||||
table="user_ips",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
|
||||
desc="get_user_ip_and_agents",
|
||||
)
|
||||
)
|
||||
# Only one result, has been upserted.
|
||||
self.assertEqual(
|
||||
result,
|
||||
[
|
||||
{
|
||||
'access_token': 'access_token',
|
||||
'ip': 'ip',
|
||||
'user_agent': 'user_agent',
|
||||
'device_id': None,
|
||||
'last_seen': 12345878000,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
def test_disabled_monthly_active_user(self):
|
||||
self.hs.config.limit_usage_by_mau = False
|
||||
self.hs.config.max_mau_value = 50
|
||||
|
|
|
@ -1,129 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.names import dns, error
|
||||
|
||||
from synapse.http.endpoint import resolve_service
|
||||
|
||||
from tests.utils import MockClock
|
||||
|
||||
from . import unittest
|
||||
|
||||
|
||||
@unittest.DEBUG
|
||||
class DnsTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def test_resolve(self):
|
||||
dns_client_mock = Mock()
|
||||
|
||||
service_name = "test_service.example.com"
|
||||
host_name = "example.com"
|
||||
|
||||
answer_srv = dns.RRHeader(
|
||||
type=dns.SRV, payload=dns.Record_SRV(target=host_name)
|
||||
)
|
||||
|
||||
dns_client_mock.lookupService.return_value = defer.succeed(
|
||||
([answer_srv], None, None)
|
||||
)
|
||||
|
||||
cache = {}
|
||||
|
||||
servers = yield resolve_service(
|
||||
service_name, dns_client=dns_client_mock, cache=cache
|
||||
)
|
||||
|
||||
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
||||
|
||||
self.assertEquals(len(servers), 1)
|
||||
self.assertEquals(servers, cache[service_name])
|
||||
self.assertEquals(servers[0].host, host_name)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_from_cache_expired_and_dns_fail(self):
|
||||
dns_client_mock = Mock()
|
||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
||||
|
||||
service_name = "test_service.example.com"
|
||||
|
||||
entry = Mock(spec_set=["expires"])
|
||||
entry.expires = 0
|
||||
|
||||
cache = {service_name: [entry]}
|
||||
|
||||
servers = yield resolve_service(
|
||||
service_name, dns_client=dns_client_mock, cache=cache
|
||||
)
|
||||
|
||||
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
||||
|
||||
self.assertEquals(len(servers), 1)
|
||||
self.assertEquals(servers, cache[service_name])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_from_cache(self):
|
||||
clock = MockClock()
|
||||
|
||||
dns_client_mock = Mock(spec_set=['lookupService'])
|
||||
dns_client_mock.lookupService = Mock(spec_set=[])
|
||||
|
||||
service_name = "test_service.example.com"
|
||||
|
||||
entry = Mock(spec_set=["expires"])
|
||||
entry.expires = 999999999
|
||||
|
||||
cache = {service_name: [entry]}
|
||||
|
||||
servers = yield resolve_service(
|
||||
service_name, dns_client=dns_client_mock, cache=cache, clock=clock
|
||||
)
|
||||
|
||||
self.assertFalse(dns_client_mock.lookupService.called)
|
||||
|
||||
self.assertEquals(len(servers), 1)
|
||||
self.assertEquals(servers, cache[service_name])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_empty_cache(self):
|
||||
dns_client_mock = Mock()
|
||||
|
||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
||||
|
||||
service_name = "test_service.example.com"
|
||||
|
||||
cache = {}
|
||||
|
||||
with self.assertRaises(error.DNSServerError):
|
||||
yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_name_error(self):
|
||||
dns_client_mock = Mock()
|
||||
|
||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
|
||||
|
||||
service_name = "test_service.example.com"
|
||||
|
||||
cache = {}
|
||||
|
||||
servers = yield resolve_service(
|
||||
service_name, dns_client=dns_client_mock, cache=cache
|
||||
)
|
||||
|
||||
self.assertEquals(len(servers), 0)
|
||||
self.assertEquals(len(cache), 0)
|
|
@ -19,7 +19,7 @@ from six import StringIO
|
|||
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
|
||||
from twisted.test.proto_helpers import AccumulatingProtocol
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
|
||||
|
@ -30,12 +30,18 @@ from synapse.util import Clock
|
|||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeTransport, make_request, render, setup_test_homeserver
|
||||
from tests.server import (
|
||||
FakeTransport,
|
||||
ThreadedMemoryReactorClock,
|
||||
make_request,
|
||||
render,
|
||||
setup_test_homeserver,
|
||||
)
|
||||
|
||||
|
||||
class JsonResourceTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.reactor = MemoryReactorClock()
|
||||
self.reactor = ThreadedMemoryReactorClock()
|
||||
self.hs_clock = Clock(self.reactor)
|
||||
self.homeserver = setup_test_homeserver(
|
||||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
|
||||
|
|
|
@ -96,7 +96,7 @@ class TestCase(unittest.TestCase):
|
|||
|
||||
method = getattr(self, methodName)
|
||||
|
||||
level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR))
|
||||
level = getattr(method, "loglevel", getattr(self, "loglevel", logging.WARNING))
|
||||
|
||||
@around(self)
|
||||
def setUp(orig):
|
||||
|
@ -333,7 +333,15 @@ class HomeserverTestCase(TestCase):
|
|||
"""
|
||||
kwargs = dict(kwargs)
|
||||
kwargs.update(self._hs_args)
|
||||
return setup_test_homeserver(self.addCleanup, *args, **kwargs)
|
||||
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
|
||||
stor = hs.get_datastore()
|
||||
|
||||
# Run the database background updates.
|
||||
if hasattr(stor, "do_next_background_update"):
|
||||
while not self.get_success(stor.has_completed_background_updates()):
|
||||
self.get_success(stor.do_next_background_update(1))
|
||||
|
||||
return hs
|
||||
|
||||
def pump(self, by=0.0):
|
||||
"""
|
||||
|
|
104
tests/util/test_async_utils.py
Normal file
104
tests/util/test_async_utils.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import CancelledError, Deferred
|
||||
from twisted.internet.task import Clock
|
||||
|
||||
from synapse.util import logcontext
|
||||
from synapse.util.async_helpers import timeout_deferred
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
from tests.unittest import TestCase
|
||||
|
||||
|
||||
class TimeoutDeferredTest(TestCase):
|
||||
def setUp(self):
|
||||
self.clock = Clock()
|
||||
|
||||
def test_times_out(self):
|
||||
"""Basic test case that checks that the original deferred is cancelled and that
|
||||
the timing-out deferred is errbacked
|
||||
"""
|
||||
cancelled = [False]
|
||||
|
||||
def canceller(_d):
|
||||
cancelled[0] = True
|
||||
|
||||
non_completing_d = Deferred(canceller)
|
||||
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
|
||||
|
||||
self.assertNoResult(timing_out_d)
|
||||
self.assertFalse(cancelled[0], "deferred was cancelled prematurely")
|
||||
|
||||
self.clock.pump((1.0, ))
|
||||
|
||||
self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
|
||||
self.failureResultOf(timing_out_d, defer.TimeoutError, )
|
||||
|
||||
def test_times_out_when_canceller_throws(self):
|
||||
"""Test that we have successfully worked around
|
||||
https://twistedmatrix.com/trac/ticket/9534"""
|
||||
|
||||
def canceller(_d):
|
||||
raise Exception("can't cancel this deferred")
|
||||
|
||||
non_completing_d = Deferred(canceller)
|
||||
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
|
||||
|
||||
self.assertNoResult(timing_out_d)
|
||||
|
||||
self.clock.pump((1.0, ))
|
||||
|
||||
self.failureResultOf(timing_out_d, defer.TimeoutError, )
|
||||
|
||||
def test_logcontext_is_preserved_on_cancellation(self):
|
||||
blocking_was_cancelled = [False]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def blocking():
|
||||
non_completing_d = Deferred()
|
||||
with logcontext.PreserveLoggingContext():
|
||||
try:
|
||||
yield non_completing_d
|
||||
except CancelledError:
|
||||
blocking_was_cancelled[0] = True
|
||||
raise
|
||||
|
||||
with logcontext.LoggingContext("one") as context_one:
|
||||
# the errbacks should be run in the test logcontext
|
||||
def errback(res, deferred_name):
|
||||
self.assertIs(
|
||||
LoggingContext.current_context(), context_one,
|
||||
"errback %s run in unexpected logcontext %s" % (
|
||||
deferred_name, LoggingContext.current_context(),
|
||||
)
|
||||
)
|
||||
return res
|
||||
|
||||
original_deferred = blocking()
|
||||
original_deferred.addErrback(errback, "orig")
|
||||
timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock)
|
||||
self.assertNoResult(timing_out_d)
|
||||
self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
|
||||
timing_out_d.addErrback(errback, "timingout")
|
||||
|
||||
self.clock.pump((1.0, ))
|
||||
|
||||
self.assertTrue(
|
||||
blocking_was_cancelled[0],
|
||||
"non-completing deferred was not cancelled",
|
||||
)
|
||||
self.failureResultOf(timing_out_d, defer.TimeoutError, )
|
||||
self.assertIs(LoggingContext.current_context(), context_one)
|
Loading…
Reference in a new issue