mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 12:33:49 +01:00
Features
--------
 
 - Python 3.5 and 3.6 support is now in beta.
 ([\#3576](https://github.com/matrix-org/synapse/issues/3576))
 - Implement `event_format` filter param in `/sync`
 ([\#3790](https://github.com/matrix-org/synapse/issues/3790))
 - Add synapse_admin_mau:registered_reserved_users metric to expose
 number of real reserved users
 ([\#3846](https://github.com/matrix-org/synapse/issues/3846))
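
The `event_format` filter param above is passed inside the filter JSON of a `/sync` request. A minimal sketch of building such a request URL (the homeserver name and access token are placeholders; the filter shape follows the Matrix client-server filtering API):

```python
import json
from urllib.parse import urlencode

# Hypothetical filter asking for federation-format events; "client" is
# the default when event_format is omitted.
filter_json = {
    "event_format": "federation",
    "room": {"timeline": {"limit": 10}},
}

# The filter is sent as URL-encoded JSON in the query string.
query = urlencode({
    "filter": json.dumps(filter_json),
    "access_token": "SECRET_TOKEN",  # placeholder
})
sync_url = "https://matrix.example.com/_matrix/client/r0/sync?" + query
```

A filter can also be registered once via `POST /user/{userId}/filter` and referenced by ID, which avoids re-sending the JSON on every sync.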
 
 Bugfixes
 --------
 
 - Remove connection ID for replication prometheus metrics, as it creates
 a large number of new series.
 ([\#3788](https://github.com/matrix-org/synapse/issues/3788))
 - Guest users should no longer be counted in the MAU total
 ([\#3800](https://github.com/matrix-org/synapse/issues/3800))
 - Bump dependency on pyopenssl 16.x, to avoid incompatibility with
 recent Twisted.
 ([\#3804](https://github.com/matrix-org/synapse/issues/3804))
 - Fix existing room tags not coming down sync when joining a room
 ([\#3810](https://github.com/matrix-org/synapse/issues/3810))
 - Fix jwt import check
 ([\#3824](https://github.com/matrix-org/synapse/issues/3824))
 - Fix VoIP crashes under Python 3 (#3821)
 ([\#3835](https://github.com/matrix-org/synapse/issues/3835))
 - Fix manhole so that it works with latest openssh clients
 ([\#3841](https://github.com/matrix-org/synapse/issues/3841))
 - Fix outbound requests occasionally wedging, which can result in
 federation breaking between servers.
 ([\#3845](https://github.com/matrix-org/synapse/issues/3845))
 - Show heroes if room name/canonical alias has been deleted
 ([\#3851](https://github.com/matrix-org/synapse/issues/3851))
 - Fix handling of redacted events from federation
 ([\#3859](https://github.com/matrix-org/synapse/issues/3859))
 -  ([\#3874](https://github.com/matrix-org/synapse/issues/3874))
 - Mitigate outbound federation randomly becoming wedged
 ([\#3875](https://github.com/matrix-org/synapse/issues/3875))
 
 Internal Changes
 ----------------
 
 - CircleCI tests now run on the potential merge of a PR.
 ([\#3704](https://github.com/matrix-org/synapse/issues/3704))
 - http/ is now ported to Python 3.
 ([\#3771](https://github.com/matrix-org/synapse/issues/3771))
 - Improve human readable error messages for threepid
 registration/account update
 ([\#3789](https://github.com/matrix-org/synapse/issues/3789))
 - Make /sync slightly faster by avoiding needless copies
 ([\#3795](https://github.com/matrix-org/synapse/issues/3795))
 - handlers/ is now ported to Python 3.
 ([\#3803](https://github.com/matrix-org/synapse/issues/3803))
 - Limit the number of PDUs/EDUs per federation transaction
 ([\#3805](https://github.com/matrix-org/synapse/issues/3805))
 - Only start postgres instance for postgres tests on Travis CI
 ([\#3806](https://github.com/matrix-org/synapse/issues/3806))
 - tests/ is now ported to Python 3.
 ([\#3808](https://github.com/matrix-org/synapse/issues/3808))
 - crypto/ is now ported to Python 3.
 ([\#3822](https://github.com/matrix-org/synapse/issues/3822))
 - rest/ is now ported to Python 3.
 ([\#3823](https://github.com/matrix-org/synapse/issues/3823))
 - add some logging for the keyring queue
 ([\#3826](https://github.com/matrix-org/synapse/issues/3826))
 - speed up lazy loading by 2-3x
 ([\#3827](https://github.com/matrix-org/synapse/issues/3827))
 - Improved Dockerfile to remove build requirements after building
 reducing the image size.
 ([\#3834](https://github.com/matrix-org/synapse/issues/3834))
 - Disable lazy loading for incremental syncs for now
 ([\#3840](https://github.com/matrix-org/synapse/issues/3840))
 - federation/ is now ported to Python 3.
 ([\#3847](https://github.com/matrix-org/synapse/issues/3847))
 - Log when we retry outbound requests
 ([\#3853](https://github.com/matrix-org/synapse/issues/3853))
 - Removed some excess logging messages.
 ([\#3855](https://github.com/matrix-org/synapse/issues/3855))
 - Speed up purge history for rooms that have been previously purged
 ([\#3856](https://github.com/matrix-org/synapse/issues/3856))
 - Refactor some HTTP timeout code.
 ([\#3857](https://github.com/matrix-org/synapse/issues/3857))
 - Fix running merged builds on CircleCI
 ([\#3858](https://github.com/matrix-org/synapse/issues/3858))
 - Fix typo in replication stream exception.
 ([\#3860](https://github.com/matrix-org/synapse/issues/3860))
 - Add in flight real time metrics for Measure blocks
 ([\#3871](https://github.com/matrix-org/synapse/issues/3871))
 - Disable buffering and automatic retrying in treq requests to prevent
 timeouts. ([\#3872](https://github.com/matrix-org/synapse/issues/3872))
 - mention jemalloc in the README
 ([\#3877](https://github.com/matrix-org/synapse/issues/3877))
 - Remove unmaintained "nuke-room-from-db.sh" script
 ([\#3888](https://github.com/matrix-org/synapse/issues/3888))

Merge tag 'v0.33.5'

Amber Brown 2018-09-24 23:41:35 +10:00
commit 829213523e
87 changed files with 1998 additions and 1091 deletions


@@ -9,6 +9,8 @@ jobs:
       - store_artifacts:
           path: ~/project/logs
           destination: logs
+      - store_test_results:
+          path: logs
   sytestpy2postgres:
     machine: true
     steps:
@@ -18,15 +20,45 @@ jobs:
       - store_artifacts:
           path: ~/project/logs
           destination: logs
+      - store_test_results:
+          path: logs
+  sytestpy2merged:
+    machine: true
+    steps:
+      - checkout
+      - run: bash .circleci/merge_base_branch.sh
+      - run: docker pull matrixdotorg/sytest-synapsepy2
+      - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs matrixdotorg/sytest-synapsepy2
+      - store_artifacts:
+          path: ~/project/logs
+          destination: logs
+      - store_test_results:
+          path: logs
+  sytestpy2postgresmerged:
+    machine: true
+    steps:
+      - checkout
+      - run: bash .circleci/merge_base_branch.sh
+      - run: docker pull matrixdotorg/sytest-synapsepy2
+      - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs -e POSTGRES=1 matrixdotorg/sytest-synapsepy2
+      - store_artifacts:
+          path: ~/project/logs
+          destination: logs
+      - store_test_results:
+          path: logs
   sytestpy3:
     machine: true
     steps:
       - checkout
       - run: docker pull matrixdotorg/sytest-synapsepy3
-      - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs hawkowl/sytestpy3
+      - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs matrixdotorg/sytest-synapsepy3
       - store_artifacts:
           path: ~/project/logs
           destination: logs
+      - store_test_results:
+          path: logs
   sytestpy3postgres:
     machine: true
     steps:
@@ -36,6 +68,32 @@ jobs:
       - store_artifacts:
           path: ~/project/logs
           destination: logs
+      - store_test_results:
+          path: logs
+  sytestpy3merged:
+    machine: true
+    steps:
+      - checkout
+      - run: bash .circleci/merge_base_branch.sh
+      - run: docker pull matrixdotorg/sytest-synapsepy3
+      - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs matrixdotorg/sytest-synapsepy3
+      - store_artifacts:
+          path: ~/project/logs
+          destination: logs
+      - store_test_results:
+          path: logs
+  sytestpy3postgresmerged:
+    machine: true
+    steps:
+      - checkout
+      - run: bash .circleci/merge_base_branch.sh
+      - run: docker pull matrixdotorg/sytest-synapsepy3
+      - run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs -e POSTGRES=1 matrixdotorg/sytest-synapsepy3
+      - store_artifacts:
+          path: ~/project/logs
+          destination: logs
+      - store_test_results:
+          path: logs

 workflows:
   version: 2
@@ -43,6 +101,21 @@ workflows:
     jobs:
       - sytestpy2
       - sytestpy2postgres
-      # Currently broken while the Python 3 port is incomplete
-      # - sytestpy3
-      # - sytestpy3postgres
+      - sytestpy3
+      - sytestpy3postgres
+      - sytestpy2merged:
+          filters:
+            branches:
+              ignore: /develop|master/
+      - sytestpy2postgresmerged:
+          filters:
+            branches:
+              ignore: /develop|master/
+      - sytestpy3merged:
+          filters:
+            branches:
+              ignore: /develop|master/
+      - sytestpy3postgresmerged:
+          filters:
+            branches:
+              ignore: /develop|master/

.circleci/merge_base_branch.sh Executable file

@@ -0,0 +1,31 @@
#!/usr/bin/env bash
set -e
# CircleCI doesn't give CIRCLE_PR_NUMBER in the environment for non-forked PRs. Wonderful.
# In this case, we just need to do some ~shell magic~ to strip it out of the PULL_REQUEST URL.
echo 'export CIRCLE_PR_NUMBER="${CIRCLE_PR_NUMBER:-${CIRCLE_PULL_REQUEST##*/}}"' >> $BASH_ENV
source $BASH_ENV
if [[ -z "${CIRCLE_PR_NUMBER}" ]]
then
echo "Can't figure out what the PR number is!"
exit 1
fi
# Get the reference, using the GitHub API
GITBASE=`curl -q https://api.github.com/repos/matrix-org/synapse/pulls/${CIRCLE_PR_NUMBER} | jq -r '.base.ref'`
# Show what we are before
git show -s
# Set up username so it can do a merge
git config --global user.email bot@matrix.org
git config --global user.name "A robot"
# Fetch and merge. If it doesn't work, it will raise due to set -e.
git fetch -u origin $GITBASE
git merge --no-edit origin/$GITBASE
# Show what we are after.
git show -s

.gitignore vendored

@@ -44,6 +44,7 @@ media_store/
 build/
 venv/
 venv*/
+*venv/
 localhost-800*/
 static/client/register/register_config.js


@@ -8,9 +8,6 @@ before_script:
   - git remote set-branches --add origin develop
   - git fetch origin develop

-services:
-  - postgresql
-
 matrix:
   fast_finish: true
   include:
@@ -25,6 +22,11 @@ matrix:
     - python: 2.7
       env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
+      services:
+        - postgresql
+
+    - python: 3.5
+      env: TOX_ENV=py35

     - python: 3.6
       env: TOX_ENV=py36


@@ -1,3 +1,67 @@
+Synapse 0.33.5 (2018-09-24)
+===========================
+
+No significant changes.
+
+
+Synapse 0.33.5rc1 (2018-09-17)
+==============================
+
+Features
+--------
+
+- Python 3.5 and 3.6 support is now in beta. ([\#3576](https://github.com/matrix-org/synapse/issues/3576))
+- Implement `event_format` filter param in `/sync` ([\#3790](https://github.com/matrix-org/synapse/issues/3790))
+- Add synapse_admin_mau:registered_reserved_users metric to expose number of real reserved users ([\#3846](https://github.com/matrix-org/synapse/issues/3846))
+
+
+Bugfixes
+--------
+
+- Remove connection ID for replication prometheus metrics, as it creates a large number of new series. ([\#3788](https://github.com/matrix-org/synapse/issues/3788))
+- guest users should not be part of mau total ([\#3800](https://github.com/matrix-org/synapse/issues/3800))
+- Bump dependency on pyopenssl 16.x, to avoid incompatibility with recent Twisted. ([\#3804](https://github.com/matrix-org/synapse/issues/3804))
+- Fix existing room tags not coming down sync when joining a room ([\#3810](https://github.com/matrix-org/synapse/issues/3810))
+- Fix jwt import check ([\#3824](https://github.com/matrix-org/synapse/issues/3824))
+- fix VOIP crashes under Python 3 (#3821) ([\#3835](https://github.com/matrix-org/synapse/issues/3835))
+- Fix manhole so that it works with latest openssh clients ([\#3841](https://github.com/matrix-org/synapse/issues/3841))
+- Fix outbound requests occasionally wedging, which can result in federation breaking between servers. ([\#3845](https://github.com/matrix-org/synapse/issues/3845))
+- Show heroes if room name/canonical alias has been deleted ([\#3851](https://github.com/matrix-org/synapse/issues/3851))
+- Fix handling of redacted events from federation ([\#3859](https://github.com/matrix-org/synapse/issues/3859))
+- ([\#3874](https://github.com/matrix-org/synapse/issues/3874))
+- Mitigate outbound federation randomly becoming wedged ([\#3875](https://github.com/matrix-org/synapse/issues/3875))
+
+
+Internal Changes
+----------------
+
+- CircleCI tests now run on the potential merge of a PR. ([\#3704](https://github.com/matrix-org/synapse/issues/3704))
+- http/ is now ported to Python 3. ([\#3771](https://github.com/matrix-org/synapse/issues/3771))
+- Improve human readable error messages for threepid registration/account update ([\#3789](https://github.com/matrix-org/synapse/issues/3789))
+- Make /sync slightly faster by avoiding needless copies ([\#3795](https://github.com/matrix-org/synapse/issues/3795))
+- handlers/ is now ported to Python 3. ([\#3803](https://github.com/matrix-org/synapse/issues/3803))
+- Limit the number of PDUs/EDUs per federation transaction ([\#3805](https://github.com/matrix-org/synapse/issues/3805))
+- Only start postgres instance for postgres tests on Travis CI ([\#3806](https://github.com/matrix-org/synapse/issues/3806))
+- tests/ is now ported to Python 3. ([\#3808](https://github.com/matrix-org/synapse/issues/3808))
+- crypto/ is now ported to Python 3. ([\#3822](https://github.com/matrix-org/synapse/issues/3822))
+- rest/ is now ported to Python 3. ([\#3823](https://github.com/matrix-org/synapse/issues/3823))
+- add some logging for the keyring queue ([\#3826](https://github.com/matrix-org/synapse/issues/3826))
+- speed up lazy loading by 2-3x ([\#3827](https://github.com/matrix-org/synapse/issues/3827))
+- Improved Dockerfile to remove build requirements after building reducing the image size. ([\#3834](https://github.com/matrix-org/synapse/issues/3834))
+- Disable lazy loading for incremental syncs for now ([\#3840](https://github.com/matrix-org/synapse/issues/3840))
+- federation/ is now ported to Python 3. ([\#3847](https://github.com/matrix-org/synapse/issues/3847))
+- Log when we retry outbound requests ([\#3853](https://github.com/matrix-org/synapse/issues/3853))
+- Removed some excess logging messages. ([\#3855](https://github.com/matrix-org/synapse/issues/3855))
+- Speed up purge history for rooms that have been previously purged ([\#3856](https://github.com/matrix-org/synapse/issues/3856))
+- Refactor some HTTP timeout code. ([\#3857](https://github.com/matrix-org/synapse/issues/3857))
+- Fix running merged builds on CircleCI ([\#3858](https://github.com/matrix-org/synapse/issues/3858))
+- Fix typo in replication stream exception. ([\#3860](https://github.com/matrix-org/synapse/issues/3860))
+- Add in flight real time metrics for Measure blocks ([\#3871](https://github.com/matrix-org/synapse/issues/3871))
+- Disable buffering and automatic retrying in treq requests to prevent timeouts. ([\#3872](https://github.com/matrix-org/synapse/issues/3872))
+- mention jemalloc in the README ([\#3877](https://github.com/matrix-org/synapse/issues/3877))
+- Remove unmaintained "nuke-room-from-db.sh" script ([\#3888](https://github.com/matrix-org/synapse/issues/3888))
+
+
 Synapse 0.33.4 (2018-09-07)
 ===========================


@@ -742,6 +742,18 @@ so an example nginx configuration might look like::
         }
     }

+and an example apache configuration may look like::
+
+    <VirtualHost *:443>
+        SSLEngine on
+        ServerName matrix.example.com;
+
+        <Location /_matrix>
+            ProxyPass http://127.0.0.1:8008/_matrix nocanon
+            ProxyPassReverse http://127.0.0.1:8008/_matrix
+        </Location>
+    </VirtualHost>
+
 You will also want to set ``bind_addresses: ['127.0.0.1']`` and ``x_forwarded: true``
 for port 8008 in ``homeserver.yaml`` to ensure that client IP addresses are
 recorded correctly.
@@ -951,5 +963,13 @@ variable. The default is 0.5, which can be decreased to reduce RAM usage
 in memory constrained environments, or increased if performance starts to
 degrade.

+Using `libjemalloc <http://jemalloc.net/>`_ can also yield a significant
+improvement in overall memory use, and especially in terms of giving back
+RAM to the OS. To use it, the library must simply be put in the LD_PRELOAD
+environment variable when launching Synapse. On Debian, this can be done
+by installing the ``libjemalloc1`` package and adding this line to
+``/etc/default/matrix-synapse``::
+
+    LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libjemalloc.so.1
+
 .. _`key_management`: https://matrix.org/docs/spec/server_server/unstable.html#retrieving-server-keys


@@ -1,6 +1,8 @@
 FROM docker.io/python:2-alpine3.8

-RUN apk add --no-cache --virtual .nacl_deps \
+COPY . /synapse
+
+RUN apk add --no-cache --virtual .build_deps \
         build-base \
         libffi-dev \
         libjpeg-turbo-dev \
@@ -8,13 +10,16 @@ RUN apk add --no-cache --virtual .nacl_deps \
         libxslt-dev \
         linux-headers \
         postgresql-dev \
-        su-exec \
-        zlib-dev
-
-COPY . /synapse
-
-# A wheel cache may be provided in ./cache for faster build
-RUN cd /synapse \
+        zlib-dev \
+    && cd /synapse \
+    && apk add --no-cache --virtual .runtime_deps \
+        libffi \
+        libjpeg-turbo \
+        libressl \
+        libxslt \
+        libpq \
+        zlib \
+        su-exec \
     && pip install --upgrade \
         lxml \
         pip \
@@ -26,8 +31,9 @@ RUN cd /synapse \
     && rm -rf \
         setup.cfg \
         setup.py \
-        synapse
+        synapse \
+    && apk del .build_deps

 VOLUME ["/data"]
 EXPOSE 8008/tcp 8448/tcp


@@ -1,57 +0,0 @@
#!/bin/bash
## CAUTION:
## This script will remove (hopefully) all trace of the given room ID from
## your homeserver.db
## Do not run it lightly.
set -e
if [ "$1" == "-h" ] || [ "$1" == "" ]; then
echo "Call with ROOM_ID as first option and then pipe it into the database. So for instance you might run"
echo " nuke-room-from-db.sh <room_id> | sqlite3 homeserver.db"
echo "or"
echo " nuke-room-from-db.sh <room_id> | psql --dbname=synapse"
exit
fi
ROOMID="$1"
cat <<EOF
DELETE FROM event_forward_extremities WHERE room_id = '$ROOMID';
DELETE FROM event_backward_extremities WHERE room_id = '$ROOMID';
DELETE FROM event_edges WHERE room_id = '$ROOMID';
DELETE FROM room_depth WHERE room_id = '$ROOMID';
DELETE FROM state_forward_extremities WHERE room_id = '$ROOMID';
DELETE FROM events WHERE room_id = '$ROOMID';
DELETE FROM event_json WHERE room_id = '$ROOMID';
DELETE FROM state_events WHERE room_id = '$ROOMID';
DELETE FROM current_state_events WHERE room_id = '$ROOMID';
DELETE FROM room_memberships WHERE room_id = '$ROOMID';
DELETE FROM feedback WHERE room_id = '$ROOMID';
DELETE FROM topics WHERE room_id = '$ROOMID';
DELETE FROM room_names WHERE room_id = '$ROOMID';
DELETE FROM rooms WHERE room_id = '$ROOMID';
DELETE FROM room_hosts WHERE room_id = '$ROOMID';
DELETE FROM room_aliases WHERE room_id = '$ROOMID';
DELETE FROM state_groups WHERE room_id = '$ROOMID';
DELETE FROM state_groups_state WHERE room_id = '$ROOMID';
DELETE FROM receipts_graph WHERE room_id = '$ROOMID';
DELETE FROM receipts_linearized WHERE room_id = '$ROOMID';
DELETE FROM event_search WHERE room_id = '$ROOMID';
DELETE FROM guest_access WHERE room_id = '$ROOMID';
DELETE FROM history_visibility WHERE room_id = '$ROOMID';
DELETE FROM room_tags WHERE room_id = '$ROOMID';
DELETE FROM room_tags_revisions WHERE room_id = '$ROOMID';
DELETE FROM room_account_data WHERE room_id = '$ROOMID';
DELETE FROM event_push_actions WHERE room_id = '$ROOMID';
DELETE FROM local_invites WHERE room_id = '$ROOMID';
DELETE FROM pusher_throttle WHERE room_id = '$ROOMID';
DELETE FROM event_reports WHERE room_id = '$ROOMID';
DELETE FROM public_room_list_stream WHERE room_id = '$ROOMID';
DELETE FROM stream_ordering_to_exterm WHERE room_id = '$ROOMID';
DELETE FROM event_auth WHERE room_id = '$ROOMID';
DELETE FROM appservice_room_list WHERE room_id = '$ROOMID';
VACUUM;
EOF


@@ -17,13 +17,14 @@ ignore =
 [pep8]
 max-line-length = 90
 # W503 requires that binary operators be at the end, not start, of lines. Erik
-# doesn't like it. E203 is contrary to PEP8.
-ignore = W503,E203
+# doesn't like it. E203 is contrary to PEP8. E731 is silly.
+ignore = W503,E203,E731

 [flake8]
 # note that flake8 inherits the "ignore" settings from "pep8" (because it uses
 # pep8 to do those checks), but not the "max-line-length" setting
 max-line-length = 90
+ignore=W503,E203,E731

 [isort]
 line_length = 89


@@ -17,4 +17,14 @@
 """ This is a reference implementation of a Matrix home server.
 """

-__version__ = "0.33.4"
+try:
+    from twisted.internet import protocol
+    from twisted.internet.protocol import Factory
+    from twisted.names.dns import DNSDatagramProtocol
+    protocol.Factory.noisy = False
+    Factory.noisy = False
+    DNSDatagramProtocol.noisy = False
+except ImportError:
+    pass
+
+__version__ = "0.33.5"


@@ -251,6 +251,7 @@ class FilterCollection(object):
             "include_leave", False
         )
         self.event_fields = filter_json.get("event_fields", [])
+        self.event_format = filter_json.get("event_format", "client")

     def __repr__(self):
         return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
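
The one-line change above gives `event_format` a default. A standalone sketch of the observable behaviour (a hypothetical helper, not the real `FilterCollection` class):

```python
def get_event_format(filter_json):
    # Mirrors the diff above: fall back to "client" when the filter JSON
    # does not specify an event_format.
    return filter_json.get("event_format", "client")
```

So an empty filter yields client-format events, and a client must opt in explicitly to receive federation-format events.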


@@ -307,6 +307,10 @@ class SynapseHomeServer(HomeServer):
 # Gauges to expose monthly active user control metrics

 current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
 max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
+registered_reserved_users_mau_gauge = Gauge(
+    "synapse_admin_mau:registered_reserved_users",
+    "Registered users with reserved threepids"
+)


 def setup(config_options):
@@ -531,10 +535,14 @@ def run(hs):

     @defer.inlineCallbacks
     def generate_monthly_active_users():
-        count = 0
+        current_mau_count = 0
+        reserved_count = 0
+        store = hs.get_datastore()
         if hs.config.limit_usage_by_mau:
-            count = yield hs.get_datastore().get_monthly_active_count()
-        current_mau_gauge.set(float(count))
+            current_mau_count = yield store.get_monthly_active_count()
+            reserved_count = yield store.get_registered_reserved_users_count()
+        current_mau_gauge.set(float(current_mau_count))
+        registered_reserved_users_mau_gauge.set(float(reserved_count))
         max_mau_gauge.set(float(hs.config.max_mau_value))

         hs.get_datastore().initialise_reserved_users(
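
The metric logic added above can be sketched standalone (a hypothetical `set_gauge` callback stands in for `prometheus_client`'s `Gauge.set`; the count callables stand in for the datastore queries):

```python
def update_mau_gauges(limit_usage_by_mau, get_mau_count, get_reserved_count, set_gauge):
    # Both counts default to 0 when MAU limiting is disabled, as in the diff.
    current_mau_count = 0
    reserved_count = 0
    if limit_usage_by_mau:
        current_mau_count = get_mau_count()
        reserved_count = get_reserved_count()
    set_gauge("synapse_admin_mau:current", float(current_mau_count))
    set_gauge("synapse_admin_mau:registered_reserved_users", float(reserved_count))

# Example: record the gauges into a plain dict.
gauges = {}
update_mau_gauges(True, lambda: 42, lambda: 5, gauges.__setitem__)
```

Exposing the reserved-user count separately lets operators see how much of the MAU figure is consumed by users whose threepids are exempt from the limit.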


@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import urllib
+
+from six.moves import urllib

 from prometheus_client import Counter

@@ -98,7 +99,7 @@ class ApplicationServiceApi(SimpleHttpClient):
     def query_user(self, service, user_id):
         if service.url is None:
             defer.returnValue(False)
-        uri = service.url + ("/users/%s" % urllib.quote(user_id))
+        uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
         response = None
         try:
             response = yield self.get_json(uri, {
@@ -119,7 +120,7 @@ class ApplicationServiceApi(SimpleHttpClient):
     def query_alias(self, service, alias):
         if service.url is None:
             defer.returnValue(False)
-        uri = service.url + ("/rooms/%s" % urllib.quote(alias))
+        uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
         response = None
         try:
             response = yield self.get_json(uri, {
@@ -153,7 +154,7 @@ class ApplicationServiceApi(SimpleHttpClient):
                 service.url,
                 APP_SERVICE_PREFIX,
                 kind,
-                urllib.quote(protocol)
+                urllib.parse.quote(protocol)
             )
             try:
                 response = yield self.get_json(uri, fields)
@@ -188,7 +189,7 @@ class ApplicationServiceApi(SimpleHttpClient):
             uri = "%s%s/thirdparty/protocol/%s" % (
                 service.url,
                 APP_SERVICE_PREFIX,
-                urllib.quote(protocol)
+                urllib.parse.quote(protocol)
             )
             try:
                 info = yield self.get_json(uri, {})
@@ -228,7 +229,7 @@ class ApplicationServiceApi(SimpleHttpClient):
         txn_id = str(txn_id)
         uri = service.url + ("/transactions/%s" %
-                            urllib.quote(txn_id))
+                            urllib.parse.quote(txn_id))
         try:
             yield self.put_json(
                 uri=uri,
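
The mechanical change in this file is `urllib.quote` → `urllib.parse.quote` via `six.moves`, which resolves to the stdlib `urllib.parse` module on Python 3. A quick illustration of the escaping these URIs rely on (the application-service base URL is hypothetical):

```python
from urllib.parse import quote  # what six.moves.urllib.parse aliases on Python 3

# Matrix user IDs contain '@' and ':', which must be percent-encoded
# before being embedded in an application-service URI path.
escaped = quote("@user:example.com")
uri = "https://as.example.com/users/%s" % escaped
```

Using `six.moves` keeps the same import line working on Python 2, where it resolves to the old `urllib`/`urlparse` split.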


@@ -21,7 +21,7 @@ from .consent_config import ConsentConfig
 from .database import DatabaseConfig
 from .emailconfig import EmailConfig
 from .groups import GroupsConfig
-from .jwt import JWTConfig
+from .jwt_config import JWTConfig
 from .key import KeyConfig
 from .logger import LoggingConfig
 from .metrics import MetricsConfig


@@ -227,7 +227,22 @@ def setup_logging(config, use_worker_options=False):
     #
     # However this may not be too much of a problem if we are just writing to a file.
     observer = STDLibLogObserver()

+    def _log(event):
+        if "log_text" in event:
+            if event["log_text"].startswith("DNSDatagramProtocol starting on "):
+                return
+
+            if event["log_text"].startswith("(UDP Port "):
+                return
+
+            if event["log_text"].startswith("Timing out client"):
+                return
+
+        return observer(event)
+
     globalLogBeginner.beginLoggingTo(
-        [observer],
+        [_log],
         redirectStandardIO=not config.no_redirect_stdio,
     )


@@ -123,6 +123,6 @@ class ClientTLSOptionsFactory(object):
     def get_options(self, host):
         return ClientTLSOptions(
-            host.decode('utf-8'),
+            host,
             CertificateOptions(verify=False).getContext()
         )

View file

@@ -50,7 +50,7 @@ def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
            defer.returnValue((server_response, server_certificate))
        except SynapseKeyClientError as e:
            logger.warn("Error getting key for %r: %s", server_name, e)
-           if e.status.startswith("4"):
+           if e.status.startswith(b"4"):
                # Don't retry for 4xx responses.
                raise IOError("Cannot get key for %r" % server_name)
        except (ConnectError, DomainError) as e:
@@ -82,6 +82,12 @@ class SynapseKeyClientProtocol(HTTPClient):
        self._peer = self.transport.getPeer()
        logger.debug("Connected to %s", self._peer)
+
+       if not isinstance(self.path, bytes):
+           self.path = self.path.encode('ascii')
+
+       if not isinstance(self.host, bytes):
+           self.host = self.host.encode('ascii')
+
        self.sendCommand(b"GET", self.path)
        if self.host:
            self.sendHeader(b"Host", self.host)

View file
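The key-client hunks deal with Python 3 making HTTP wire data `bytes`: the status line must be compared against a bytes literal, and path/host are coerced to bytes before being written to the transport. A small helper in the spirit of the `isinstance` checks above (`to_ascii_bytes` is a hypothetical name, not a synapse function):

```python
def to_ascii_bytes(value):
    """Coerce str to ASCII bytes; pass bytes through unchanged.

    Mirrors the isinstance checks in the hunk above: under Python 3 a
    path or host may arrive as str, but the HTTP layer wants bytes.
    """
    if not isinstance(value, bytes):
        value = value.encode("ascii")
    return value

# status lines come back as bytes under Python 3, so the prefix we
# compare against must be bytes too ("4" vs b"404" would never match)
status = b"404"
print(status.startswith(b"4"))        # True
print(to_ascii_bytes("/key/v2/server"))
```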

@@ -16,9 +16,10 @@

 import hashlib
 import logging
-import urllib
 from collections import namedtuple

+from six.moves import urllib
+
 from signedjson.key import (
     decode_verify_key_bytes,
     encode_verify_key_base64,
@@ -40,6 +41,7 @@ from synapse.api.errors import Codes, SynapseError
 from synapse.crypto.keyclient import fetch_server_key
 from synapse.util import logcontext, unwrapFirstError
 from synapse.util.logcontext import (
+    LoggingContext,
     PreserveLoggingContext,
     preserve_fn,
     run_in_background,
@@ -216,23 +218,34 @@ class Keyring(object):
           servers have completed. Follows the synapse rules of logcontext
           preservation.
        """
+       loop_count = 1
        while True:
            wait_on = [
-               self.key_downloads[server_name]
+               (server_name, self.key_downloads[server_name])
                for server_name in server_names
                if server_name in self.key_downloads
            ]
-           if wait_on:
-               with PreserveLoggingContext():
-                   yield defer.DeferredList(wait_on)
-           else:
+           if not wait_on:
                break
+           logger.info(
+               "Waiting for existing lookups for %s to complete [loop %i]",
+               [w[0] for w in wait_on], loop_count,
+           )
+           with PreserveLoggingContext():
+               yield defer.DeferredList((w[1] for w in wait_on))
+           loop_count += 1
+
+       ctx = LoggingContext.current_context()

        def rm(r, server_name_):
-           self.key_downloads.pop(server_name_, None)
+           with PreserveLoggingContext(ctx):
+               logger.debug("Releasing key lookup lock on %s", server_name_)
+               self.key_downloads.pop(server_name_, None)
            return r

        for server_name, deferred in server_to_deferred.items():
+           logger.debug("Got key lookup lock on %s", server_name)
            self.key_downloads[server_name] = deferred
            deferred.addBoth(rm, server_name)
@@ -432,7 +445,7 @@ class Keyring(object):
            # an incoming request.
            query_response = yield self.client.post_json(
                destination=perspective_name,
-               path=b"/_matrix/key/v2/query",
+               path="/_matrix/key/v2/query",
                data={
                    u"server_keys": {
                        server_name: {
@@ -513,8 +526,8 @@ class Keyring(object):
            (response, tls_certificate) = yield fetch_server_key(
                server_name, self.hs.tls_client_options_factory,
-               path=(b"/_matrix/key/v2/server/%s" % (
-                   urllib.quote(requested_key_id),
+               path=("/_matrix/key/v2/server/%s" % (
+                   urllib.parse.quote(requested_key_id),
                )).encode("ascii"),
            )

View file

@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import six
+
 from synapse.util.caches import intern_dict
 from synapse.util.frozenutils import freeze
@@ -147,6 +149,9 @@ class EventBase(object):
    def items(self):
        return list(self._event_dict.items())

+   def keys(self):
+       return six.iterkeys(self._event_dict)
+

 class FrozenEvent(EventBase):
    def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):

View file

@@ -143,11 +143,31 @@ class FederationBase(object):
        def callback(_, pdu):
            with logcontext.PreserveLoggingContext(ctx):
                if not check_event_content_hash(pdu):
-                   logger.warn(
-                       "Event content has been tampered, redacting %s: %s",
-                       pdu.event_id, pdu.get_pdu_json()
-                   )
-                   return prune_event(pdu)
+                   # let's try to distinguish between failures because the event was
+                   # redacted (which are somewhat expected) vs actual ball-tampering
+                   # incidents.
+                   #
+                   # This is just a heuristic, so we just assume that if the keys are
+                   # about the same between the redacted and received events, then the
+                   # received event was probably a redacted copy (but we then use our
+                   # *actual* redacted copy to be on the safe side.)
+                   redacted_event = prune_event(pdu)
+                   if (
+                       set(redacted_event.keys()) == set(pdu.keys()) and
+                       set(six.iterkeys(redacted_event.content))
+                       == set(six.iterkeys(pdu.content))
+                   ):
+                       logger.info(
+                           "Event %s seems to have been redacted; using our redacted "
+                           "copy",
+                           pdu.event_id,
+                       )
+                   else:
+                       logger.warning(
+                           "Event %s content has been tampered, redacting",
+                           pdu.event_id, pdu.get_pdu_json(),
+                       )
+                   return redacted_event

                if self.spam_checker.check_event_for_spam(pdu):
                    logger.warn(
@@ -162,8 +182,8 @@ class FederationBase(object):
            failure.trap(SynapseError)
            with logcontext.PreserveLoggingContext(ctx):
                logger.warn(
-                   "Signature check failed for %s",
-                   pdu.event_id,
+                   "Signature check failed for %s: %s",
+                   pdu.event_id, failure.getErrorMessage(),
                )
            return failure

View file
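The hunk above distinguishes "this event arrived already redacted" (expected) from "the content was tampered with" by comparing the key sets of the received event against its own pruned copy. The heuristic, sketched on plain dicts — `prune_content` is a toy stand-in for synapse's `prune_event`, which in reality keeps many more protocol-level keys:

```python
def prune_content(event):
    # toy stand-in for prune_event: pretend a redacted event of this
    # type keeps no content keys at all
    redacted = dict(event)
    redacted["content"] = {}
    return redacted

def looks_redacted(received):
    """Heuristic from the hunk above: if the received event has the same
    keys as our own redacted copy, it was probably a redacted copy
    rather than a tampered one."""
    redacted = prune_content(received)
    return (
        set(redacted.keys()) == set(received.keys()) and
        set(redacted["content"].keys()) == set(received["content"].keys())
    )

tampered = {"event_id": "$1", "content": {"body": "hi", "evil": True}}
already_redacted = {"event_id": "$2", "content": {}}
print(looks_redacted(tampered))         # False: extra content keys
print(looks_redacted(already_redacted)) # True: matches our pruned copy
```

Either way the server uses its *own* pruned copy, so a forged "redacted" event can't smuggle content through; the heuristic only decides how loudly to log.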

@ -271,10 +271,10 @@ class FederationClient(FederationBase):
event_id, destination, e, event_id, destination, e,
) )
except NotRetryingDestination as e: except NotRetryingDestination as e:
logger.info(e.message) logger.info(str(e))
continue continue
except FederationDeniedError as e: except FederationDeniedError as e:
logger.info(e.message) logger.info(str(e))
continue continue
except Exception as e: except Exception as e:
pdu_attempts[destination] = now pdu_attempts[destination] = now
@ -510,7 +510,7 @@ class FederationClient(FederationBase):
else: else:
logger.warn( logger.warn(
"Failed to %s via %s: %i %s", "Failed to %s via %s: %i %s",
description, destination, e.code, e.message, description, destination, e.code, e.args[0],
) )
except Exception: except Exception:
logger.warn( logger.warn(
@ -875,7 +875,7 @@ class FederationClient(FederationBase):
except Exception as e: except Exception as e:
logger.exception( logger.exception(
"Failed to send_third_party_invite via %s: %s", "Failed to send_third_party_invite via %s: %s",
destination, e.message destination, str(e)
) )
raise RuntimeError("Failed to send to any server.") raise RuntimeError("Failed to send to any server.")

View file

@@ -838,9 +838,9 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
        )

        return self._send_edu(
-               edu_type=edu_type,
-               origin=origin,
-               content=content,
+           edu_type=edu_type,
+           origin=origin,
+           content=content,
        )

    def on_query(self, query_type, args):
@@ -851,6 +851,6 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
            return handler(args)

        return self._get_query_client(
-               query_type=query_type,
-               args=args,
+           query_type=query_type,
+           args=args,
        )

View file

@@ -463,7 +463,19 @@ class TransactionQueue(object):
                # pending_transactions flag.

                pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+
+               # We can only include at most 50 PDUs per transaction
+               pending_pdus, leftover_pdus = pending_pdus[:50], pending_pdus[50:]
+               if leftover_pdus:
+                   self.pending_pdus_by_dest[destination] = leftover_pdus
+
                pending_edus = self.pending_edus_by_dest.pop(destination, [])
+
+               # We can only include at most 100 EDUs per transaction
+               pending_edus, leftover_edus = pending_edus[:100], pending_edus[100:]
+               if leftover_edus:
+                   self.pending_edus_by_dest[destination] = leftover_edus
+
                pending_presence = self.pending_presence_by_dest.pop(destination, {})

                pending_edus.extend(
View file
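The transaction-queue hunk above caps each federation transaction at 50 PDUs / 100 EDUs by slicing the pending list and pushing the remainder back onto the queue for the next transaction. The take-a-batch pattern in isolation (`take_batch` is a hypothetical helper name):

```python
def take_batch(pending_by_dest, destination, max_items):
    """Pop up to max_items entries for destination, re-queueing the rest.

    Same shape as the hunk above: pending[:50] goes into this
    transaction, pending[50:] waits for the next one.
    """
    pending = pending_by_dest.pop(destination, [])
    batch, leftover = pending[:max_items], pending[max_items:]
    if leftover:
        pending_by_dest[destination] = leftover
    return batch

queue = {"example.org": list(range(120))}
first = take_batch(queue, "example.org", 50)
second = take_batch(queue, "example.org", 50)
third = take_batch(queue, "example.org", 50)
print(len(first), len(second), len(third))  # 50 50 20
```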

@@ -15,7 +15,8 @@
 # limitations under the License.

 import logging
-import urllib
+
+from six.moves import urllib

 from twisted.internet import defer
@@ -951,4 +952,4 @@ def _create_path(prefix, path, *args):
    Returns:
        str
    """
-   return prefix + path % tuple(urllib.quote(arg, "") for arg in args)
+   return prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args)

View file

@@ -90,8 +90,8 @@ class Authenticator(object):
    @defer.inlineCallbacks
    def authenticate_request(self, request, content):
        json_request = {
-           "method": request.method,
-           "uri": request.uri,
+           "method": request.method.decode('ascii'),
+           "uri": request.uri.decode('ascii'),
            "destination": self.server_name,
            "signatures": {},
        }
@@ -252,7 +252,7 @@ class BaseFederationServlet(object):
                by the callback method. None if the request has already been handled.
            """
            content = None
-           if request.method in ["PUT", "POST"]:
+           if request.method in [b"PUT", b"POST"]:
                # TODO: Handle other method types? other content types?
                content = parse_json_object_from_request(request)
@@ -386,7 +386,7 @@ class FederationStateServlet(BaseFederationServlet):
        return self.handler.on_context_state_request(
            origin,
            context,
-           query.get("event_id", [None])[0],
+           parse_string_from_args(query, "event_id", None),
        )
@@ -397,7 +397,7 @@ class FederationStateIdsServlet(BaseFederationServlet):
        return self.handler.on_state_ids_request(
            origin,
            room_id,
-           query.get("event_id", [None])[0],
+           parse_string_from_args(query, "event_id", None),
        )
@@ -405,14 +405,12 @@ class FederationBackfillServlet(BaseFederationServlet):
    PATH = "/backfill/(?P<context>[^/]*)/"

    def on_GET(self, origin, content, query, context):
-       versions = query["v"]
-       limits = query["limit"]
+       versions = [x.decode('ascii') for x in query[b"v"]]
+       limit = parse_integer_from_args(query, "limit", None)

-       if not limits:
+       if not limit:
            return defer.succeed((400, {"error": "Did not include limit param"}))

-       limit = int(limits[-1])
-
        return self.handler.on_backfill_request(origin, context, versions, limit)
@@ -423,7 +421,7 @@ class FederationQueryServlet(BaseFederationServlet):
    def on_GET(self, origin, content, query, query_type):
        return self.handler.on_query_request(
            query_type,
-           {k: v[0].decode("utf-8") for k, v in query.items()}
+           {k.decode('utf8'): v[0].decode("utf-8") for k, v in query.items()}
        )
@@ -630,14 +628,14 @@ class OpenIdUserInfo(BaseFederationServlet):
    @defer.inlineCallbacks
    def on_GET(self, origin, content, query):
-       token = query.get("access_token", [None])[0]
+       token = query.get(b"access_token", [None])[0]
        if token is None:
            defer.returnValue((401, {
                "errcode": "M_MISSING_TOKEN", "error": "Access Token required"
            }))
            return

-       user_id = yield self.handler.on_openid_userinfo(token)
+       user_id = yield self.handler.on_openid_userinfo(token.decode('ascii'))

        if user_id is None:
            defer.returnValue((401, {

View file
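The servlet hunks above all follow from Twisted's `request.args` mapping *bytes* keys to lists of *bytes* values under Python 3, hence `query[b"v"]` and the explicit decodes. A sketch of a `parse_string_from_args`-style helper, simplified relative to synapse's real one:

```python
def parse_string_from_args(args, name, default=None):
    """Fetch a single query parameter from a Twisted-style args dict.

    args maps bytes keys to lists of bytes values under Python 3,
    e.g. {b"event_id": [b"$abc"]}; callers want a plain str back.
    """
    if not isinstance(name, bytes):
        name = name.encode("ascii")
    values = args.get(name)
    if not values:
        return default
    return values[0].decode("ascii")

query = {b"event_id": [b"$abc:example.org"], b"v": [b"$x", b"$y"]}
print(parse_string_from_args(query, "event_id"))         # $abc:example.org
print(parse_string_from_args(query, "missing", "dflt"))  # dflt
```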

@@ -895,22 +895,24 @@ class AuthHandler(BaseHandler):

        Args:
            password (unicode): Password to hash.
-           stored_hash (unicode): Expected hash value.
+           stored_hash (bytes): Expected hash value.

        Returns:
            Deferred(bool): Whether self.hash(password) == stored_hash.
        """
+
        def _do_validate_hash():
            # Normalise the Unicode in the password
            pw = unicodedata.normalize("NFKC", password)

            return bcrypt.checkpw(
                pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
-               stored_hash.encode('utf8')
+               stored_hash
            )

        if stored_hash:
+           if not isinstance(stored_hash, bytes):
+               stored_hash = stored_hash.encode('ascii')
+
            return make_deferred_yieldable(
                threads.deferToThreadPool(
                    self.hs.get_reactor(),

View file
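The auth hunk NFKC-normalises the password and hands `bcrypt.checkpw` bytes on both sides (the stored hash is coerced with an `isinstance` check). The normalisation step in isolation, stdlib-only so bcrypt itself is left out; `prepare_password` is a hypothetical helper name:

```python
import unicodedata

def prepare_password(password, pepper=""):
    """NFKC-normalise a unicode password and return utf-8 bytes,
    as in the hunk above (bcrypt.checkpw wants bytes, not str)."""
    pw = unicodedata.normalize("NFKC", password)
    return pw.encode("utf8") + pepper.encode("utf8")

# "ﬁ" (U+FB01, a single ligature codepoint) normalises to "fi" under
# NFKC, so both spellings of the password yield the same bytes to check
a = prepare_password("\ufb01sh")
b = prepare_password("fish")
print(a == b)  # True
```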

@@ -330,7 +330,8 @@ class E2eKeysHandler(object):
                    (algorithm, key_id, ex_json, key)
                )
            else:
-               new_keys.append((algorithm, key_id, encode_canonical_json(key)))
+               new_keys.append((
+                   algorithm, key_id, encode_canonical_json(key).decode('ascii')))

        yield self.store.add_e2e_one_time_keys(
            user_id, device_id, time_now, new_keys
@@ -358,7 +359,7 @@ def _exception_to_failure(e):
    # Note that some Exceptions (notably twisted's ResponseFailed etc) don't
    # give a string for e.message, which json then fails to serialize.
    return {
-       "status": 503, "message": str(e.message),
+       "status": 503, "message": str(e),
    }

View file

@@ -594,7 +594,7 @@ class FederationHandler(BaseHandler):

        required_auth = set(
            a_id
-           for event in events + state_events.values() + auth_events.values()
+           for event in events + list(state_events.values()) + list(auth_events.values())
            for a_id, _ in event.auth_events
        )
        auth_events.update({
@@ -802,7 +802,7 @@ class FederationHandler(BaseHandler):
                )
                continue
            except NotRetryingDestination as e:
-               logger.info(e.message)
+               logger.info(str(e))
                continue
            except FederationDeniedError as e:
                logger.info(e)
@@ -1358,7 +1358,7 @@ class FederationHandler(BaseHandler):
        )

        if state_groups:
-           _, state = state_groups.items().pop()
+           _, state = list(state_groups.items()).pop()
            results = state

            if event.is_state():

View file
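Several hunks above wrap dict views in `list(...)` because Python 3's `dict.items()`/`dict.values()` return views: they can't be concatenated to a list and have no `.pop()`. A minimal demonstration of both failure modes and the fixes applied in this diff:

```python
d = {"a": 1, "b": 2}

# Python 2: d.items() is a list, so d.items().pop() works.
# Python 3: items() is a view with no pop(), so it raises.
try:
    d.items().pop()
except AttributeError:
    pass  # expected on Python 3

# the fixes in the hunks above: materialise the view first
_, value = list(d.items()).pop()      # pops ("b", 2), insertion order
combined = ["x"] + list(d.values())   # list + view would raise TypeError
print(value, combined)
```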

@@ -269,14 +269,7 @@ class PaginationHandler(object):

        if state_ids:
            state = yield self.store.get_events(list(state_ids.values()))
-
-           if state:
-               state = yield filter_events_for_client(
-                   self.store,
-                   user_id,
-                   state.values(),
-                   is_peeking=(member_event_id is None),
-               )
+           state = state.values()

        time_now = self.clock.time_msec()

View file

@@ -162,7 +162,7 @@ class RoomListHandler(BaseHandler):
        # Filter out rooms that we don't want to return
        rooms_to_scan = [
            r for r in sorted_rooms
-           if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0
+           if r not in newly_unpublished and rooms_to_num_joined[r] > 0
        ]

        total_room_count = len(rooms_to_scan)

View file

@@ -54,7 +54,7 @@ class SearchHandler(BaseHandler):
        batch_token = None
        if batch:
            try:
-               b = decode_base64(batch)
+               b = decode_base64(batch).decode('ascii')
                batch_group, batch_group_key, batch_token = b.split("\n")

                assert batch_group is not None
@@ -258,18 +258,18 @@ class SearchHandler(BaseHandler):
            # it returns more from the same group (if applicable) rather
            # than reverting to searching all results again.
            if batch_group and batch_group_key:
-               global_next_batch = encode_base64("%s\n%s\n%s" % (
+               global_next_batch = encode_base64(("%s\n%s\n%s" % (
                    batch_group, batch_group_key, pagination_token
-               ))
+               )).encode('ascii'))
            else:
-               global_next_batch = encode_base64("%s\n%s\n%s" % (
+               global_next_batch = encode_base64(("%s\n%s\n%s" % (
                    "all", "", pagination_token
-               ))
+               )).encode('ascii'))

            for room_id, group in room_groups.items():
-               group["next_batch"] = encode_base64("%s\n%s\n%s" % (
+               group["next_batch"] = encode_base64(("%s\n%s\n%s" % (
                    "room_id", room_id, pagination_token
-               ))
+               )).encode('ascii'))

            allowed_events.extend(room_events)

View file
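The search hunks add `.encode('ascii')`/`.decode('ascii')` around the batch token because base64 helpers operate on bytes under Python 3. The round trip with stdlib `base64` (synapse uses its own unpadded helpers from `unpaddedbase64`; standard padded base64 is used here for illustration):

```python
from base64 import b64decode, b64encode

def encode_batch(group, key, token):
    # join with newlines, then base64 the *bytes* form; decode the
    # result so the token travels as str in JSON
    raw = ("%s\n%s\n%s" % (group, key, token)).encode("ascii")
    return b64encode(raw).decode("ascii")

def decode_batch(batch):
    # b64decode returns bytes; decode before splitting on "\n"
    group, key, token = b64decode(batch).decode("ascii").split("\n")
    return group, key, token

batch = encode_batch("room_id", "!abc:example.org", "t123")
print(decode_batch(batch))  # ('room_id', '!abc:example.org', 't123')
```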

@@ -24,6 +24,7 @@ from twisted.internet import defer

 from synapse.api.constants import EventTypes, Membership
 from synapse.push.clientformat import format_push_rules_for_user
+from synapse.storage.roommember import MemberSummary
 from synapse.types import RoomStreamToken
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -525,6 +526,8 @@ class SyncHandler(object):
            A deferred dict describing the room summary
        """

+       # FIXME: we could/should get this from room_stats when matthew/stats lands
+
        # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305
        last_events, _ = yield self.store.get_recent_event_ids_for_room(
            room_id, end_token=now_token.room_key, limit=1,
@@ -537,44 +540,67 @@ class SyncHandler(object):

        last_event = last_events[-1]
        state_ids = yield self.store.get_state_ids_for_event(
            last_event.event_id, [
-               (EventTypes.Member, None),
                (EventTypes.Name, ''),
                (EventTypes.CanonicalAlias, ''),
            ]
        )

-       member_ids = {
-           state_key: event_id
-           for (t, state_key), event_id in state_ids.iteritems()
-           if t == EventTypes.Member
-       }
+       # this is heavily cached, thus: fast.
+       details = yield self.store.get_room_summary(room_id)

        name_id = state_ids.get((EventTypes.Name, ''))
        canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ''))

        summary = {}
+       empty_ms = MemberSummary([], 0)

-       # FIXME: it feels very heavy to load up every single membership event
-       # just to calculate the counts.
-       member_events = yield self.store.get_events(member_ids.values())
-
-       joined_user_ids = []
-       invited_user_ids = []
-
-       for ev in member_events.values():
-           if ev.content.get("membership") == Membership.JOIN:
-               joined_user_ids.append(ev.state_key)
-           elif ev.content.get("membership") == Membership.INVITE:
-               invited_user_ids.append(ev.state_key)

        # TODO: only send these when they change.
-       summary["m.joined_member_count"] = len(joined_user_ids)
-       summary["m.invited_member_count"] = len(invited_user_ids)
+       summary["m.joined_member_count"] = (
+           details.get(Membership.JOIN, empty_ms).count
+       )
+       summary["m.invited_member_count"] = (
+           details.get(Membership.INVITE, empty_ms).count
+       )

-       if name_id or canonical_alias_id:
-           defer.returnValue(summary)
+       # if the room has a name or canonical_alias set, we can skip
+       # calculating heroes. we assume that if the event has contents, it'll
+       # be a valid name or canonical_alias - i.e. we're checking that they
+       # haven't been "deleted" by blatting {} over the top.
+       if name_id:
+           name = yield self.store.get_event(name_id, allow_none=False)
+           if name and name.content:
+               defer.returnValue(summary)

-       # FIXME: order by stream ordering, not alphabetic
+       if canonical_alias_id:
+           canonical_alias = yield self.store.get_event(
+               canonical_alias_id, allow_none=False,
+           )
+           if canonical_alias and canonical_alias.content:
+               defer.returnValue(summary)

+       joined_user_ids = [
+           r[0] for r in details.get(Membership.JOIN, empty_ms).members
+       ]
+       invited_user_ids = [
+           r[0] for r in details.get(Membership.INVITE, empty_ms).members
+       ]
+       gone_user_ids = (
+           [r[0] for r in details.get(Membership.LEAVE, empty_ms).members] +
+           [r[0] for r in details.get(Membership.BAN, empty_ms).members]
+       )
+
+       # FIXME: only build up a member_ids list for our heroes
+       member_ids = {}
+       for membership in (
+           Membership.JOIN,
+           Membership.INVITE,
+           Membership.LEAVE,
+           Membership.BAN,
+       ):
+           for user_id, event_id in details.get(membership, empty_ms).members:
+               member_ids[user_id] = event_id
+
+       # FIXME: order by stream ordering rather than as returned by SQL
        me = sync_config.user.to_string()
        if (joined_user_ids or invited_user_ids):
            summary['m.heroes'] = sorted(
@@ -586,7 +612,11 @@ class SyncHandler(object):
            )[0:5]
        else:
            summary['m.heroes'] = sorted(
-               [user_id for user_id in member_ids.keys() if user_id != me]
+               [
+                   user_id
+                   for user_id in gone_user_ids
+                   if user_id != me
+               ]
            )[0:5]

        if not sync_config.filter_collection.lazy_load_members():
@@ -719,6 +749,26 @@ class SyncHandler(object):
                    lazy_load_members=lazy_load_members,
                )
            elif batch.limited:
+               state_at_timeline_start = yield self.store.get_state_ids_for_event(
+                   batch.events[0].event_id, types=types,
+                   filtered_types=filtered_types,
+               )
+
+               # for now, we disable LL for gappy syncs - see
+               # https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346
+               # N.B. this slows down incr syncs as we are now processing way
+               # more state in the server than if we were LLing.
+               #
+               # We still have to filter timeline_start to LL entries (above) in order
+               # for _calculate_state's LL logic to work, as we have to include LL
+               # members for timeline senders in case they weren't loaded in the initial
+               # sync. We do this (counterintuitively) by filtering timeline_start
+               # members to just be ones which were timeline senders, which then ensures
+               # all of the rest get included in the state block (if we need to know
+               # about them).
+               types = None
+               filtered_types = None
+
                state_at_previous_sync = yield self.get_state_at(
                    room_id, stream_position=since_token, types=types,
                    filtered_types=filtered_types,
@@ -729,24 +779,21 @@ class SyncHandler(object):
                    filtered_types=filtered_types,
                )

-               state_at_timeline_start = yield self.store.get_state_ids_for_event(
-                   batch.events[0].event_id, types=types,
-                   filtered_types=filtered_types,
-               )
-
                state_ids = _calculate_state(
                    timeline_contains=timeline_state,
                    timeline_start=state_at_timeline_start,
                    previous=state_at_previous_sync,
                    current=current_state_ids,
+                   # we have to include LL members in case LL initial sync missed them
                    lazy_load_members=lazy_load_members,
                )
            else:
                state_ids = {}
                if lazy_load_members:
                    if types:
-                       # We're returning an incremental sync, with no "gap" since
-                       # the previous sync, so normally there would be no state to return
+                       # We're returning an incremental sync, with no
+                       # "gap" since the previous sync, so normally there would be
+                       # no state to return.
                        # But we're lazy-loading, so the client might need some more
                        # member events to understand the events in this timeline.
                        # So we fish out all the member events corresponding to the
@@ -774,7 +821,7 @@ class SyncHandler(object):
                logger.debug("filtering state from %r...", state_ids)
                state_ids = {
                    t: event_id
-                   for t, event_id in state_ids.iteritems()
+                   for t, event_id in iteritems(state_ids)
                    if cache.get(t[1]) != event_id
                }
                logger.debug("...to %r", state_ids)
@@ -1575,6 +1622,19 @@ class SyncHandler(object):
                newly_joined_room=newly_joined,
            )

+           # When we join the room (or the client requests full_state), we should
+           # send down any existing tags. Usually the user won't have tags in a
+           # newly joined room, unless either a) they've joined before or b) the
+           # tag was added by synapse e.g. for server notice rooms.
+           if full_state:
+               user_id = sync_result_builder.sync_config.user.to_string()
+               tags = yield self.store.get_tags_for_room(user_id, room_id)
+
+               # If there aren't any tags, don't send the empty tags list down
+               # sync
+               if not tags:
+                   tags = None
+
            account_data_events = []
            if tags is not None:
                account_data_events.append({
@@ -1603,10 +1663,24 @@ class SyncHandler(object):
            )

            summary = {}
+
+           # we include a summary in room responses when we're lazy loading
+           # members (as the client otherwise doesn't have enough info to form
+           # the name itself).
            if (
                sync_config.filter_collection.lazy_load_members() and
                (
+                   # we recalculate the summary:
+                   #   if there are membership changes in the timeline, or
+                   #   if membership has changed during a gappy sync, or
+                   #   if this is an initial sync.
                    any(ev.type == EventTypes.Member for ev in batch.events) or
+                   (
+                       # XXX: this may include false positives in the form of LL
+                       # members which have snuck into state
+                       batch.limited and
+                       any(t == EventTypes.Member for (t, k) in state)
+                   ) or
                    since_token is None
                )
            ):
@@ -1636,6 +1710,16 @@ class SyncHandler(object):
                unread_notifications["highlight_count"] = notifs["highlight_count"]

            sync_result_builder.joined.append(room_sync)
+
+           if batch.limited and since_token:
+               user_id = sync_result_builder.sync_config.user.to_string()
+               logger.info(
+                   "Incremental gappy sync of %s for user %s with %d state events" % (
+                       room_id,
+                       user_id,
+                       len(state),
+                   )
+               )
        elif room_builder.rtype == "archived":
            room_sync = ArchivedSyncResult(
                room_id=room_id,
@@ -1729,17 +1813,17 @@ def _calculate_state(
    event_id_to_key = {
        e: key
        for key, e in itertools.chain(
-           timeline_contains.items(),
-           previous.items(),
-           timeline_start.items(),
-           current.items(),
+           iteritems(timeline_contains),
+           iteritems(previous),
+           iteritems(timeline_start),
+           iteritems(current),
        )
    }

-   c_ids = set(e for e in current.values())
-   ts_ids = set(e for e in timeline_start.values())
-   p_ids = set(e for e in previous.values())
-   tc_ids = set(e for e in timeline_contains.values())
+   c_ids = set(e for e in itervalues(current))
+   ts_ids = set(e for e in itervalues(timeline_start))
+   p_ids = set(e for e in itervalues(previous))
+   tc_ids = set(e for e in itervalues(timeline_contains))

    # If we are lazyloading room members, we explicitly add the membership events
    # for the senders in the timeline into the state block returned by /sync,
@@ -1753,7 +1837,7 @@ def _calculate_state(

    if lazy_load_members:
        p_ids.difference_update(
-           e for t, e in timeline_start.iteritems()
+           e for t, e in iteritems(timeline_start)
            if t[0] == EventTypes.Member
        )

View file
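The sync hunk above computes member counts from a cached per-membership summary instead of loading every membership event, then picks up to five "heroes" for the room name. The counting-and-heroes shape, sketched standalone; `MemberSummary` is reconstructed here as a namedtuple whose `(members, count)` fields are inferred from how the hunk uses it, not taken from synapse's actual definition:

```python
from collections import namedtuple

# field layout inferred from usage above: a list of
# (user_id, event_id) pairs plus a total count
MemberSummary = namedtuple("MemberSummary", ("members", "count"))

JOIN, INVITE = "join", "invite"
empty_ms = MemberSummary([], 0)

details = {
    JOIN: MemberSummary([("@a:hs", "$1"), ("@me:hs", "$2")], 2),
    INVITE: MemberSummary([("@b:hs", "$3")], 1),
}

summary = {
    "m.joined_member_count": details.get(JOIN, empty_ms).count,
    "m.invited_member_count": details.get(INVITE, empty_ms).count,
}

# heroes: up to five members other than ourselves (sorted
# alphabetically here; the FIXME above wants stream ordering)
me = "@me:hs"
candidate_ids = [
    r[0] for r in details.get(JOIN, empty_ms).members
] + [
    r[0] for r in details.get(INVITE, empty_ms).members
]
summary["m.heroes"] = sorted(u for u in candidate_ids if u != me)[0:5]
print(summary)
```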

@@ -38,12 +38,12 @@ def cancelled_to_request_timed_out_error(value, timeout):
    return value


-ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')


 def redact_uri(uri):
    """Strips access tokens from the uri replaces with <redacted>"""
    return ACCESS_TOKEN_RE.sub(
-       br'\1<redacted>\3',
+       r'\1<redacted>\3',
        uri
    )

View file
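The hunk above switches the token-redaction regex from a bytes pattern (`br'...'`) to a str pattern, because Python 3's `re` refuses to apply a bytes pattern to a str URI. The redaction itself, using the pattern from the diff:

```python
import re

# str pattern now that URIs are str under Python 3; applying a bytes
# pattern to a str would raise TypeError
ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')

def redact_uri(uri):
    """Replace the access token value in a URI with <redacted>."""
    return ACCESS_TOKEN_RE.sub(r'\1<redacted>\3', uri)

print(redact_uri("/sync?access_token=secret123&since=s1"))
# /sync?access_token=<redacted>&since=s1
```

Only the `[^&]*` token value is replaced; the query-string prefix (group 1) and the trailing parameters (group 3) are put back around it.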

@@ -13,24 +13,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import urllib

-from six import StringIO
+from six import text_type
+from six.moves import urllib

+import treq
 from canonicaljson import encode_canonical_json, json
 from prometheus_client import Counter

 from OpenSSL import SSL
 from OpenSSL.SSL import VERIFY_NONE
-from twisted.internet import defer, protocol, reactor, ssl, task
+from twisted.internet import defer, protocol, reactor, ssl
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.web._newclient import ResponseDone
 from twisted.web.client import (
     Agent,
     BrowserLikeRedirectAgent,
     ContentDecoderAgent,
-    FileBodyProducer as TwistedFileBodyProducer,
     GzipDecoder,
     HTTPConnectionPool,
     PartialDownloadError,
@@ -83,8 +84,10 @@ class SimpleHttpClient(object):
         if hs.config.user_agent_suffix:
             self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,)
+        self.user_agent = self.user_agent.encode('ascii')

     @defer.inlineCallbacks
-    def request(self, method, uri, *args, **kwargs):
+    def request(self, method, uri, data=b'', headers=None):
         # A small wrapper around self.agent.request() so we can easily attach
         # counters to it
         outgoing_requests_counter.labels(method).inc()
@@ -93,8 +96,8 @@ class SimpleHttpClient(object):
         logger.info("Sending request %s %s", method, redact_uri(uri))
         try:
-            request_deferred = self.agent.request(
-                method, uri, *args, **kwargs
+            request_deferred = treq.request(
+                method, uri, agent=self.agent, data=data, headers=headers
             )
             add_timeout_to_deferred(
                 request_deferred, 60, self.hs.get_reactor(),
@@ -112,7 +115,7 @@ class SimpleHttpClient(object):
             incoming_responses_counter.labels(method, "ERR").inc()
             logger.info(
                 "Error sending request to %s %s: %s %s",
-                method, redact_uri(uri), type(e).__name__, e.message
+                method, redact_uri(uri), type(e).__name__, e.args[0]
             )
             raise
@@ -137,7 +140,8 @@ class SimpleHttpClient(object):
         # TODO: Do we ever want to log message contents?
         logger.debug("post_urlencoded_get_json args: %s", args)
-        query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
+        query_bytes = urllib.parse.urlencode(
+            encode_urlencode_args(args), True).encode("utf8")

         actual_headers = {
             b"Content-Type": [b"application/x-www-form-urlencoded"],
@@ -148,15 +152,14 @@ class SimpleHttpClient(object):
         response = yield self.request(
             "POST",
-            uri.encode("ascii"),
+            uri,
             headers=Headers(actual_headers),
-            bodyProducer=FileBodyProducer(StringIO(query_bytes))
+            data=query_bytes
         )

-        body = yield make_deferred_yieldable(readBody(response))
         if 200 <= response.code < 300:
-            defer.returnValue(json.loads(body))
+            body = yield make_deferred_yieldable(treq.json_content(response))
+            defer.returnValue(body)
         else:
             raise HttpResponseException(response.code, response.phrase, body)
@@ -191,9 +194,9 @@ class SimpleHttpClient(object):
         response = yield self.request(
             "POST",
-            uri.encode("ascii"),
+            uri,
             headers=Headers(actual_headers),
-            bodyProducer=FileBodyProducer(StringIO(json_str))
+            data=json_str
         )

         body = yield make_deferred_yieldable(readBody(response))
@@ -248,7 +251,7 @@ class SimpleHttpClient(object):
             ValueError: if the response was not JSON
         """
         if len(args):
-            query_bytes = urllib.urlencode(args, True)
+            query_bytes = urllib.parse.urlencode(args, True)
             uri = "%s?%s" % (uri, query_bytes)

         json_str = encode_canonical_json(json_body)
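For context on the `urlencode` calls above: the bare `True` is `doseq`, which expands list values into repeated query keys. A small illustration with hypothetical arguments:

```python
from urllib import parse  # reached via the six.moves shim in the diff

# doseq=True (the bare `True` in the calls above) expands each
# list value into one key=value pair per element.
args = {"types": ["m.room.message", "m.room.member"]}
query = parse.urlencode(args, True)
```

Without `doseq`, the list would instead be stringified as a single quoted Python literal, which is rarely what a query string wants.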
@@ -262,9 +265,9 @@ class SimpleHttpClient(object):
         response = yield self.request(
             "PUT",
-            uri.encode("ascii"),
+            uri,
             headers=Headers(actual_headers),
-            bodyProducer=FileBodyProducer(StringIO(json_str))
+            data=json_str
         )

         body = yield make_deferred_yieldable(readBody(response))
@@ -293,7 +296,7 @@ class SimpleHttpClient(object):
             HttpResponseException on a non-2xx HTTP response.
         """
         if len(args):
-            query_bytes = urllib.urlencode(args, True)
+            query_bytes = urllib.parse.urlencode(args, True)
             uri = "%s?%s" % (uri, query_bytes)

         actual_headers = {
@@ -304,7 +307,7 @@ class SimpleHttpClient(object):
         response = yield self.request(
             "GET",
-            uri.encode("ascii"),
+            uri,
             headers=Headers(actual_headers),
         )
@@ -339,13 +342,14 @@ class SimpleHttpClient(object):
         response = yield self.request(
             "GET",
-            url.encode("ascii"),
+            url,
             headers=Headers(actual_headers),
         )

         resp_headers = dict(response.headers.getAllRawHeaders())

-        if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
+        if (b'Content-Length' in resp_headers and
+                int(resp_headers[b'Content-Length']) > max_size):
             logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
             raise SynapseError(
                 502,
@@ -378,7 +382,12 @@ class SimpleHttpClient(object):
         )

         defer.returnValue(
-            (length, resp_headers, response.request.absoluteURI, response.code),
+            (
+                length,
+                resp_headers,
+                response.request.absoluteURI.decode('ascii'),
+                response.code,
+            ),
         )
@@ -434,12 +443,12 @@ class CaptchaServerHttpClient(SimpleHttpClient):
     @defer.inlineCallbacks
     def post_urlencoded_get_raw(self, url, args={}):
-        query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
+        query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True)

         response = yield self.request(
             "POST",
-            url.encode("ascii"),
-            bodyProducer=FileBodyProducer(StringIO(query_bytes)),
+            url,
+            data=query_bytes,
             headers=Headers({
                 b"Content-Type": [b"application/x-www-form-urlencoded"],
                 b"User-Agent": [self.user_agent],
@@ -463,9 +472,9 @@ class SpiderEndpointFactory(object):
     def endpointForURI(self, uri):
         logger.info("Getting endpoint for %s", uri.toBytes())

-        if uri.scheme == "http":
+        if uri.scheme == b"http":
             endpoint_factory = HostnameEndpoint
-        elif uri.scheme == "https":
+        elif uri.scheme == b"https":
             tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)

             def endpoint_factory(reactor, host, port, **kw):
@@ -510,7 +519,7 @@ def encode_urlencode_args(args):
 def encode_urlencode_arg(arg):
-    if isinstance(arg, unicode):
+    if isinstance(arg, text_type):
         return arg.encode('utf-8')
     elif isinstance(arg, list):
         return [encode_urlencode_arg(i) for i in arg]
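`text_type` from six, substituted for the Python-2-only `unicode` builtin above, is simply `str` on Python 3; a self-contained sketch of the recursive encoder with that assumption spelled out:

```python
text_type = str  # what six.text_type resolves to on Python 3


def encode_urlencode_arg(arg):
    """UTF-8-encode text arguments, recursing into lists; pass others through."""
    if isinstance(arg, text_type):
        return arg.encode('utf-8')
    elif isinstance(arg, list):
        return [encode_urlencode_arg(i) for i in arg]
    else:
        return arg
```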
@@ -542,26 +551,3 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
     def creatorForNetloc(self, hostname, port):
         return self
-
-
-class FileBodyProducer(TwistedFileBodyProducer):
-    """Workaround for https://twistedmatrix.com/trac/ticket/8473
-
-    We override the pauseProducing and resumeProducing methods in twisted's
-    FileBodyProducer so that they do not raise exceptions if the task has
-    already completed.
-    """
-
-    def pauseProducing(self):
-        try:
-            super(FileBodyProducer, self).pauseProducing()
-        except task.TaskDone:
-            # task has already completed
-            pass
-
-    def resumeProducing(self):
-        try:
-            super(FileBodyProducer, self).resumeProducing()
-        except task.NotPaused:
-            # task was not paused (probably because it had already completed)
-            pass


@@ -17,19 +17,19 @@ import cgi
 import logging
 import random
 import sys
-import urllib

-from six import string_types
-from six.moves.urllib import parse as urlparse
+from six import PY3, string_types
+from six.moves import urllib

-from canonicaljson import encode_canonical_json, json
+import treq
+from canonicaljson import encode_canonical_json
 from prometheus_client import Counter
 from signedjson.sign import sign_json

-from twisted.internet import defer, protocol, reactor
+from twisted.internet import defer, protocol
 from twisted.internet.error import DNSLookupError
 from twisted.web._newclient import ResponseDone
-from twisted.web.client import Agent, HTTPConnectionPool, readBody
+from twisted.web.client import Agent, HTTPConnectionPool
 from twisted.web.http_headers import Headers

 import synapse.metrics
@@ -40,11 +40,11 @@ from synapse.api.errors import (
     HttpResponseException,
     SynapseError,
 )
-from synapse.http import cancelled_to_request_timed_out_error
 from synapse.http.endpoint import matrix_federation_endpoint
 from synapse.util import logcontext
-from synapse.util.async_helpers import add_timeout_to_deferred
+from synapse.util.async_helpers import timeout_no_seriously
 from synapse.util.logcontext import make_deferred_yieldable
+from synapse.util.metrics import Measure

 logger = logging.getLogger(__name__)
 outbound_logger = logging.getLogger("synapse.http.outbound")
@@ -58,16 +58,22 @@ incoming_responses_counter = Counter("synapse_http_matrixfederationclient_respon
 MAX_LONG_RETRIES = 10
 MAX_SHORT_RETRIES = 3

+if PY3:
+    MAXINT = sys.maxsize
+else:
+    MAXINT = sys.maxint

 class MatrixFederationEndpointFactory(object):
     def __init__(self, hs):
+        self.reactor = hs.get_reactor()
         self.tls_client_options_factory = hs.tls_client_options_factory

     def endpointForURI(self, uri):
-        destination = uri.netloc
+        destination = uri.netloc.decode('ascii')

         return matrix_federation_endpoint(
-            reactor, destination, timeout=10,
+            self.reactor, destination, timeout=10,
             tls_client_options_factory=self.tls_client_options_factory
         )
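The `PY3` branch added above exists because `sys.maxint` is gone on Python 3; either way the client's `_next_id` transaction counter wraps at `MAXINT - 1`. A sketch of that wrap-around (the helper name is hypothetical, the constants mirror the diff):

```python
import sys

MAXINT = sys.maxsize  # the PY3 branch above; sys.maxint no longer exists


def next_txn_id(current):
    """Advance a transaction counter, wrapping before MAXINT as `_next_id` does."""
    return (current + 1) % (MAXINT - 1)
```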
@@ -85,7 +91,9 @@ class MatrixFederationHttpClient(object):
         self.hs = hs
         self.signing_key = hs.config.signing_key[0]
         self.server_name = hs.hostname
+        reactor = hs.get_reactor()
         pool = HTTPConnectionPool(reactor)
+        pool.retryAutomatically = False
         pool.maxPersistentPerHost = 5
         pool.cachedConnectionTimeout = 2 * 60
         self.agent = Agent.usingEndpointFactory(
@@ -93,26 +101,33 @@ class MatrixFederationHttpClient(object):
         )
         self.clock = hs.get_clock()
         self._store = hs.get_datastore()
-        self.version_string = hs.version_string
+        self.version_string = hs.version_string.encode('ascii')
         self._next_id = 1
+        self.default_timeout = 60

     def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
-        return urlparse.urlunparse(
-            ("matrix", destination, path_bytes, param_bytes, query_bytes, "")
+        return urllib.parse.urlunparse(
+            (b"matrix", destination, path_bytes, param_bytes, query_bytes, b"")
         )

     @defer.inlineCallbacks
     def _request(self, destination, method, path,
-                 body_callback, headers_dict={}, param_bytes=b"",
-                 query_bytes=b"", retry_on_dns_fail=True,
+                 json=None, json_callback=None,
+                 param_bytes=b"",
+                 query=None, retry_on_dns_fail=True,
                  timeout=None, long_retries=False,
                  ignore_backoff=False,
                  backoff_on_404=False):
-        """ Creates and sends a request to the given server
+        """
+        Creates and sends a request to the given server.

         Args:
             destination (str): The remote server to send the HTTP request to.
             method (str): HTTP method
             path (str): The HTTP path
+            json (dict or None): JSON to send in the body.
+            json_callback (func or None): A callback to generate the JSON.
+            query (dict or None): Query arguments.
             ignore_backoff (bool): true to ignore the historical backoff data
                 and try the request anyway.
             backoff_on_404 (bool): Back off if we get a 404
@@ -132,6 +147,11 @@ class MatrixFederationHttpClient(object):
             (May also fail with plenty of other Exceptions for things like DNS
             failures, connection failures, SSL failures.)
         """
+        if timeout:
+            _sec_timeout = timeout / 1000
+        else:
+            _sec_timeout = self.default_timeout
+
         if (
             self.hs.config.federation_domain_whitelist is not None and
             destination not in self.hs.config.federation_domain_whitelist
@@ -146,23 +166,25 @@ class MatrixFederationHttpClient(object):
             ignore_backoff=ignore_backoff,
         )

-        destination = destination.encode("ascii")
+        headers_dict = {}
         path_bytes = path.encode("ascii")
-        with limiter:
-            headers_dict[b"User-Agent"] = [self.version_string]
-            headers_dict[b"Host"] = [destination]
+        if query:
+            query_bytes = encode_query_args(query)
+        else:
+            query_bytes = b""

-            url_bytes = self._create_url(
-                destination, path_bytes, param_bytes, query_bytes
-            )
+        headers_dict = {
+            "User-Agent": [self.version_string],
+            "Host": [destination],
+        }

+        with limiter:
+            url = self._create_url(
+                destination.encode("ascii"), path_bytes, param_bytes, query_bytes
+            ).decode('ascii')

             txn_id = "%s-O-%s" % (method, self._next_id)
-            self._next_id = (self._next_id + 1) % (sys.maxint - 1)
-
-            outbound_logger.info(
-                "{%s} [%s] Sending request: %s %s",
-                txn_id, destination, method, url_bytes
-            )
+            self._next_id = (self._next_id + 1) % (MAXINT - 1)
@@ -171,80 +193,110 @@ class MatrixFederationHttpClient(object):
             else:
                 retries_left = MAX_SHORT_RETRIES

-            http_url_bytes = urlparse.urlunparse(
-                ("", "", path_bytes, param_bytes, query_bytes, "")
-            )
+            http_url = urllib.parse.urlunparse(
+                (b"", b"", path_bytes, param_bytes, query_bytes, b"")
+            ).decode('ascii')

             log_result = None
-            try:
-                while True:
-                    producer = None
-                    if body_callback:
-                        producer = body_callback(method, http_url_bytes, headers_dict)
-
-                    try:
-                        request_deferred = self.agent.request(
-                            method,
-                            url_bytes,
-                            Headers(headers_dict),
-                            producer
-                        )
-                        add_timeout_to_deferred(
-                            request_deferred,
-                            timeout / 1000. if timeout else 60,
-                            self.hs.get_reactor(),
-                            cancelled_to_request_timed_out_error,
-                        )
-                        response = yield make_deferred_yieldable(
-                            request_deferred,
-                        )
-
-                        log_result = "%d %s" % (response.code, response.phrase,)
-                        break
-                    except Exception as e:
-                        if not retry_on_dns_fail and isinstance(e, DNSLookupError):
-                            logger.warn(
-                                "DNS Lookup failed to %s with %s",
-                                destination,
-                                e
-                            )
-                            log_result = "DNS Lookup failed to %s with %s" % (
-                                destination, e
-                            )
-                            raise
-
-                        logger.warn(
-                            "{%s} Sending request failed to %s: %s %s: %s",
-                            txn_id,
-                            destination,
-                            method,
-                            url_bytes,
-                            _flatten_response_never_received(e),
-                        )
-                        log_result = _flatten_response_never_received(e)
-
-                        if retries_left and not timeout:
-                            if long_retries:
-                                delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
-                                delay = min(delay, 60)
-                                delay *= random.uniform(0.8, 1.4)
-                            else:
-                                delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
-                                delay = min(delay, 2)
-                                delay *= random.uniform(0.8, 1.4)
-                            yield self.clock.sleep(delay)
-                            retries_left -= 1
-                        else:
-                            raise
-            finally:
-                outbound_logger.info(
-                    "{%s} [%s] Result: %s",
-                    txn_id,
-                    destination,
-                    log_result,
-                )
+            while True:
+                try:
+                    if json_callback:
+                        json = json_callback()
+
+                    if json:
+                        data = encode_canonical_json(json)
+                        headers_dict["Content-Type"] = ["application/json"]
+                        self.sign_request(
+                            destination, method, http_url, headers_dict, json
+                        )
+                    else:
+                        data = None
+                        self.sign_request(destination, method, http_url, headers_dict)
+
+                    outbound_logger.info(
+                        "{%s} [%s] Sending request: %s %s",
+                        txn_id, destination, method, url
+                    )
+
+                    request_deferred = treq.request(
+                        method,
+                        url,
+                        headers=Headers(headers_dict),
+                        data=data,
+                        agent=self.agent,
+                        reactor=self.hs.get_reactor(),
+                        unbuffered=True
+                    )
+                    request_deferred.addTimeout(_sec_timeout, self.hs.get_reactor())
+
+                    # Sometimes the timeout above doesn't work, so lets hack yet
+                    # another layer of timeouts in in the vain hope that at some
+                    # point the world made sense and this really really really
+                    # should work.
+                    request_deferred = timeout_no_seriously(
+                        request_deferred,
+                        timeout=_sec_timeout * 2,
+                        reactor=self.hs.get_reactor(),
+                    )
+
+                    with Measure(self.clock, "outbound_request"):
+                        response = yield make_deferred_yieldable(
+                            request_deferred,
+                        )
+
+                    log_result = "%d %s" % (response.code, response.phrase,)
+                    break
+                except Exception as e:
+                    if not retry_on_dns_fail and isinstance(e, DNSLookupError):
+                        logger.warn(
+                            "DNS Lookup failed to %s with %s",
+                            destination,
+                            e
+                        )
+                        log_result = "DNS Lookup failed to %s with %s" % (
+                            destination, e
+                        )
+                        raise
+
+                    logger.warn(
+                        "{%s} Sending request failed to %s: %s %s: %s",
+                        txn_id,
+                        destination,
+                        method,
+                        url,
+                        _flatten_response_never_received(e),
+                    )
+
+                    log_result = _flatten_response_never_received(e)
+
+                    if retries_left and not timeout:
+                        if long_retries:
+                            delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
+                            delay = min(delay, 60)
+                            delay *= random.uniform(0.8, 1.4)
+                        else:
+                            delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
+                            delay = min(delay, 2)
+                            delay *= random.uniform(0.8, 1.4)

+                        logger.debug(
+                            "{%s} Waiting %s before sending to %s...",
+                            txn_id,
+                            delay,
+                            destination
+                        )
+
+                        yield self.clock.sleep(delay)
+                        retries_left -= 1
+                    else:
+                        raise
+                finally:
+                    outbound_logger.info(
+                        "{%s} [%s] Result: %s",
+                        txn_id,
+                        destination,
+                        log_result,
+                    )

             if 200 <= response.code < 300:
                 pass
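The retry arithmetic in the loop above (exponential growth, a cap, then ±40%/−20% jitter) can be pulled out for inspection; a sketch mirroring the constants and formula, with a hypothetical helper name:

```python
import random

MAX_LONG_RETRIES = 10
MAX_SHORT_RETRIES = 3


def retry_delay(retries_left, long_retries):
    """Delay before the next attempt, as computed in the hunk above."""
    if long_retries:
        delay = min(4 ** (MAX_LONG_RETRIES + 1 - retries_left), 60)
    else:
        delay = min(0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left), 2)
    # Multiplicative jitter spreads retries from many workers apart.
    return delay * random.uniform(0.8, 1.4)
```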
@@ -252,7 +304,9 @@ class MatrixFederationHttpClient(object):
             # :'(
             # Update transactions table?
             with logcontext.PreserveLoggingContext():
-                body = yield readBody(response)
+                d = treq.content(response)
+                d.addTimeout(_sec_timeout, self.hs.get_reactor())
+                body = yield make_deferred_yieldable(d)
             raise HttpResponseException(
                 response.code, response.phrase, body
             )
@@ -297,11 +351,11 @@ class MatrixFederationHttpClient(object):
         auth_headers = []

         for key, sig in request["signatures"][self.server_name].items():
-            auth_headers.append(bytes(
+            auth_headers.append((
                 "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
                     self.server_name, key, sig,
-                )
-            ))
+                )).encode('ascii')
+            )

         headers_dict[b"Authorization"] = auth_headers
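The change above swaps the Python-2 `bytes(...)` cast for an explicit `.encode('ascii')`; extracted into a hypothetical standalone helper, the header construction looks like this:

```python
def x_matrix_auth_headers(server_name, signatures):
    """Build X-Matrix Authorization header values as ascii bytes.

    `signatures` maps key IDs to base64 signature strings, as in the
    request's signatures block above.
    """
    auth_headers = []
    for key, sig in signatures.items():
        auth_headers.append((
            "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (server_name, key, sig)
        ).encode('ascii'))
    return auth_headers
```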
@@ -347,24 +401,14 @@ class MatrixFederationHttpClient(object):
         """
         if not json_data_callback:
-            def json_data_callback():
-                return data
-
-        def body_callback(method, url_bytes, headers_dict):
-            json_data = json_data_callback()
-            self.sign_request(
-                destination, method, url_bytes, headers_dict, json_data
-            )
-            producer = _JsonProducer(json_data)
-            return producer
+            json_data_callback = lambda: data

         response = yield self._request(
             destination,
             "PUT",
             path,
-            body_callback=body_callback,
-            headers_dict={"Content-Type": ["application/json"]},
-            query_bytes=encode_query_args(args),
+            json_callback=json_data_callback,
+            query=args,
             long_retries=long_retries,
             timeout=timeout,
             ignore_backoff=ignore_backoff,
@@ -376,8 +420,10 @@ class MatrixFederationHttpClient(object):
         check_content_type_is_json(response.headers)

         with logcontext.PreserveLoggingContext():
-            body = yield readBody(response)
+            d = treq.json_content(response)
+            d.addTimeout(self.default_timeout, self.hs.get_reactor())
+            body = yield make_deferred_yieldable(d)

-        defer.returnValue(json.loads(body))
+        defer.returnValue(body)

     @defer.inlineCallbacks
     def post_json(self, destination, path, data={}, long_retries=False,
@@ -410,20 +456,12 @@ class MatrixFederationHttpClient(object):
             Fails with ``FederationDeniedError`` if this destination
             is not on our federation whitelist
         """
-        def body_callback(method, url_bytes, headers_dict):
-            self.sign_request(
-                destination, method, url_bytes, headers_dict, data
-            )
-            return _JsonProducer(data)
-
         response = yield self._request(
             destination,
             "POST",
             path,
-            query_bytes=encode_query_args(args),
-            body_callback=body_callback,
-            headers_dict={"Content-Type": ["application/json"]},
+            query=args,
+            json=data,
             long_retries=long_retries,
             timeout=timeout,
             ignore_backoff=ignore_backoff,
@@ -434,9 +472,16 @@ class MatrixFederationHttpClient(object):
         check_content_type_is_json(response.headers)

         with logcontext.PreserveLoggingContext():
-            body = yield readBody(response)
+            d = treq.json_content(response)
+            if timeout:
+                _sec_timeout = timeout / 1000
+            else:
+                _sec_timeout = self.default_timeout

-        defer.returnValue(json.loads(body))
+            d.addTimeout(_sec_timeout, self.hs.get_reactor())
+            body = yield make_deferred_yieldable(d)
+
+        defer.returnValue(body)

     @defer.inlineCallbacks
     def get_json(self, destination, path, args=None, retry_on_dns_fail=True,
@@ -471,16 +516,11 @@ class MatrixFederationHttpClient(object):
         logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)

-        def body_callback(method, url_bytes, headers_dict):
-            self.sign_request(destination, method, url_bytes, headers_dict)
-            return None
-
         response = yield self._request(
             destination,
             "GET",
             path,
-            query_bytes=encode_query_args(args),
-            body_callback=body_callback,
+            query=args,
             retry_on_dns_fail=retry_on_dns_fail,
             timeout=timeout,
             ignore_backoff=ignore_backoff,
@@ -491,9 +531,11 @@ class MatrixFederationHttpClient(object):
         check_content_type_is_json(response.headers)

         with logcontext.PreserveLoggingContext():
-            body = yield readBody(response)
+            d = treq.json_content(response)
+            d.addTimeout(self.default_timeout, self.hs.get_reactor())
+            body = yield make_deferred_yieldable(d)

-        defer.returnValue(json.loads(body))
+        defer.returnValue(body)

     @defer.inlineCallbacks
     def delete_json(self, destination, path, long_retries=False,
@@ -523,13 +565,11 @@ class MatrixFederationHttpClient(object):
             Fails with ``FederationDeniedError`` if this destination
             is not on our federation whitelist
         """
         response = yield self._request(
             destination,
             "DELETE",
             path,
-            query_bytes=encode_query_args(args),
-            headers_dict={"Content-Type": ["application/json"]},
+            query=args,
             long_retries=long_retries,
             timeout=timeout,
             ignore_backoff=ignore_backoff,
@@ -540,9 +580,11 @@ class MatrixFederationHttpClient(object):
         check_content_type_is_json(response.headers)

         with logcontext.PreserveLoggingContext():
-            body = yield readBody(response)
+            d = treq.json_content(response)
+            d.addTimeout(self.default_timeout, self.hs.get_reactor())
+            body = yield make_deferred_yieldable(d)

-        defer.returnValue(json.loads(body))
+        defer.returnValue(body)

     @defer.inlineCallbacks
     def get_file(self, destination, path, output_stream, args={},
@@ -569,26 +611,11 @@ class MatrixFederationHttpClient(object):
             Fails with ``FederationDeniedError`` if this destination
             is not on our federation whitelist
         """
-        encoded_args = {}
-        for k, vs in args.items():
-            if isinstance(vs, string_types):
-                vs = [vs]
-            encoded_args[k] = [v.encode("UTF-8") for v in vs]
-
-        query_bytes = urllib.urlencode(encoded_args, True)
-        logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail)
-
-        def body_callback(method, url_bytes, headers_dict):
-            self.sign_request(destination, method, url_bytes, headers_dict)
-            return None
-
         response = yield self._request(
             destination,
             "GET",
             path,
-            query_bytes=query_bytes,
-            body_callback=body_callback,
+            query=args,
             retry_on_dns_fail=retry_on_dns_fail,
             ignore_backoff=ignore_backoff,
         )
@@ -597,9 +624,9 @@ class MatrixFederationHttpClient(object):
         try:
             with logcontext.PreserveLoggingContext():
-                length = yield _readBodyToFile(
-                    response, output_stream, max_size
-                )
+                d = _readBodyToFile(response, output_stream, max_size)
+                d.addTimeout(self.default_timeout, self.hs.get_reactor())
+                length = yield make_deferred_yieldable(d)
         except Exception:
             logger.exception("Failed to download body")
             raise
@@ -639,30 +666,6 @@ def _readBodyToFile(response, stream, max_size):
     return d


-class _JsonProducer(object):
-    """ Used by the twisted http client to create the HTTP body from json
-    """
-    def __init__(self, jsn):
-        self.reset(jsn)
-
-    def reset(self, jsn):
-        self.body = encode_canonical_json(jsn)
-        self.length = len(self.body)
-
-    def startProducing(self, consumer):
-        consumer.write(self.body)
-        return defer.succeed(None)
-
-    def pauseProducing(self):
-        pass
-
-    def stopProducing(self):
-        pass
-
-    def resumeProducing(self):
-        pass
-
-
 def _flatten_response_never_received(e):
     if hasattr(e, "reasons"):
         reasons = ", ".join(
@@ -693,7 +696,7 @@ def check_content_type_is_json(headers):
             "No Content-Type header"
         )

-    c_type = c_type[0]  # only the first header
+    c_type = c_type[0].decode('ascii')  # only the first header
     val, options = cgi.parse_header(c_type)
     if val != "application/json":
         raise RuntimeError(
@@ -711,6 +714,6 @@ def encode_query_args(args):
             vs = [vs]
         encoded_args[k] = [v.encode("UTF-8") for v in vs]

-    query_bytes = urllib.urlencode(encoded_args, True)
+    query_bytes = urllib.parse.urlencode(encoded_args, True)

-    return query_bytes
+    return query_bytes.encode('utf8')


@@ -85,7 +85,10 @@ class SynapseRequest(Request):
         return "%s-%i" % (self.method, self.request_seq)

     def get_redacted_uri(self):
-        return redact_uri(self.uri)
+        uri = self.uri
+        if isinstance(uri, bytes):
+            uri = self.uri.decode('ascii')
+        return redact_uri(uri)

     def get_user_agent(self):
         return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
@@ -204,14 +207,14 @@ class SynapseRequest(Request):
         self.start_time = time.time()
         self.request_metrics = RequestMetrics()
         self.request_metrics.start(
-            self.start_time, name=servlet_name, method=self.method,
+            self.start_time, name=servlet_name, method=self.method.decode('ascii'),
         )

         self.site.access_logger.info(
             "%s - %s - Received request: %s %s",
             self.getClientIP(),
             self.site.site_tag,
-            self.method,
+            self.method.decode('ascii'),
             self.get_redacted_uri()
         )


@@ -18,8 +18,11 @@ import gc
 import logging
 import os
 import platform
+import threading
 import time

+import six
 import attr
 from prometheus_client import Counter, Gauge, Histogram
 from prometheus_client.core import REGISTRY, GaugeMetricFamily
@@ -68,7 +71,7 @@ class LaterGauge(object):
             return

         if isinstance(calls, dict):
-            for k, v in calls.items():
+            for k, v in six.iteritems(calls):
                 g.add_metric(k, v)
         else:
             g.add_metric([], calls)
@@ -87,6 +90,109 @@ class LaterGauge(object):
         all_gauges[self.name] = self


+class InFlightGauge(object):
+    """Tracks number of things (e.g. requests, Measure blocks, etc) in flight
+    at any given time.
+
+    Each InFlightGauge will create a metric called `<name>_total` that counts
+    the number of in flight blocks, as well as a metrics for each item in the
+    given `sub_metrics` as `<name>_<sub_metric>` which will get updated by the
+    callbacks.
+
+    Args:
+        name (str)
+        desc (str)
+        labels (list[str])
+        sub_metrics (list[str]): A list of sub metrics that the callbacks
+            will update.
+    """
+
+    def __init__(self, name, desc, labels, sub_metrics):
+        self.name = name
+        self.desc = desc
+        self.labels = labels
+        self.sub_metrics = sub_metrics
+
+        # Create a class which have the sub_metrics values as attributes, which
+        # default to 0 on initialization. Used to pass to registered callbacks.
+        self._metrics_class = attr.make_class(
+            "_MetricsEntry",
+            attrs={x: attr.ib(0) for x in sub_metrics},
+            slots=True,
+        )
+
+        # Counts number of in flight blocks for a given set of label values
+        self._registrations = {}
+
+        # Protects access to _registrations
+        self._lock = threading.Lock()
+
+        self._register_with_collector()
+
+    def register(self, key, callback):
+        """Registers that we've entered a new block with labels `key`.
+
+        `callback` gets called each time the metrics are collected. The same
+        value must also be given to `unregister`.
+
+        `callback` gets called with an object that has an attribute per
+        sub_metric, which should be updated with the necessary values. Note that
+        the metrics object is shared between all callbacks registered with the
+        same key.
+
+        Note that `callback` may be called on a separate thread.
+        """
+        with self._lock:
+            self._registrations.setdefault(key, set()).add(callback)
+
+    def unregister(self, key, callback):
+        """Registers that we've exited a block with labels `key`.
+        """
+        with self._lock:
+            self._registrations.setdefault(key, set()).discard(callback)
+
+    def collect(self):
+        """Called by prometheus client when it reads metrics.
+
+        Note: may be called by a separate thread.
+        """
+        in_flight = GaugeMetricFamily(self.name + "_total", self.desc, labels=self.labels)
+
+        metrics_by_key = {}
+
+        # We copy so that we don't mutate the list while iterating
+        with self._lock:
+            keys = list(self._registrations)
+
+        for key in keys:
+            with self._lock:
+                callbacks = set(self._registrations[key])
+
+            in_flight.add_metric(key, len(callbacks))
+
+            metrics = self._metrics_class()
+            metrics_by_key[key] = metrics
+            for callback in callbacks:
+                callback(metrics)
+
+        yield in_flight
+
+        for name in self.sub_metrics:
+            gauge = GaugeMetricFamily("_".join([self.name, name]), "", labels=self.labels)
for key, metrics in six.iteritems(metrics_by_key):
gauge.add_metric(key, getattr(metrics, name))
yield gauge
def _register_with_collector(self):
if self.name in all_gauges.keys():
logger.warning("%s already registered, reregistering" % (self.name,))
REGISTRY.unregister(all_gauges.pop(self.name))
REGISTRY.register(self)
all_gauges[self.name] = self
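The core of the new `InFlightGauge` is the register/unregister/collect pattern: callbacks are tracked per label-key under a lock, and `collect()` snapshots the keys before invoking them. A stdlib-only sketch of that pattern (class and method names here are illustrative, not Synapse's API):

```python
import threading

class InFlightTracker(object):
    """Sketch of the InFlightGauge pattern above: callbacks registered
    per label-key, snapshotted under a lock at collect time."""

    def __init__(self):
        self._registrations = {}          # label tuple -> set of callbacks
        self._lock = threading.Lock()     # protects _registrations

    def register(self, key, callback):
        with self._lock:
            self._registrations.setdefault(key, set()).add(callback)

    def unregister(self, key, callback):
        with self._lock:
            self._registrations.setdefault(key, set()).discard(callback)

    def collect(self):
        # Snapshot the keys so we don't mutate while iterating; the real
        # collect() may be called from a separate prometheus thread.
        with self._lock:
            keys = list(self._registrations)
        return {key: len(self._registrations[key]) for key in keys}

tracker = InFlightTracker()
cb = lambda metrics: None
tracker.register(("GET",), cb)
assert tracker.collect() == {("GET",): 1}
tracker.unregister(("GET",), cb)
assert tracker.collect() == {("GET",): 0}
```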
#
# Detailed CPU metrics
#


@@ -15,6 +15,8 @@
# limitations under the License.

import logging

+import six

from prometheus_client import Counter

from twisted.internet import defer

@@ -26,6 +28,9 @@ from synapse.util.metrics import Measure
from . import push_rule_evaluator, push_tools

+if six.PY3:
+    long = int
+
logger = logging.getLogger(__name__)

http_push_processed_counter = Counter("synapse_http_httppusher_http_pushes_processed", "")

@@ -96,7 +101,7 @@ class HttpPusher(object):
    @defer.inlineCallbacks
    def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
-        self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
+        self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0)
        yield self._process()

    @defer.inlineCallbacks
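The `or 0` matters because `self.max_stream_ordering` can be `None` before the pusher has seen any notification. Python 2 allowed ordering comparisons between `None` and ints; Python 3 raises `TypeError`. A minimal demonstration:

```python
max_stream_ordering = None   # state before any notification is seen

failed = False
try:
    max(5, max_stream_ordering)          # what the pre-patch code did
except TypeError:
    failed = True                        # Python 3 refuses None vs int
assert failed

result = max(5, max_stream_ordering or 0)  # the patched form
assert result == 5
```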


@@ -17,10 +17,11 @@ import email.mime.multipart
import email.utils
import logging
import time
-import urllib

from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

+from six.moves import urllib
+
import bleach
import jinja2

@@ -474,7 +475,7 @@ class Mailer(object):
        # XXX: make r0 once API is stable
        return "%s_matrix/client/unstable/pushers/remove?%s" % (
            self.hs.config.public_baseurl,
-            urllib.urlencode(params),
+            urllib.parse.urlencode(params),
        )

@@ -561,7 +562,7 @@ def _create_mxc_to_http_filter(config):
        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
            config.public_baseurl,
            serverAndMediaId,
-            urllib.urlencode(params),
+            urllib.parse.urlencode(params),
            fragment or "",
        )
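Python 2's top-level `urllib.urlencode` moved to `urllib.parse.urlencode` in Python 3; `six.moves.urllib` papers over the difference. On Python 3 the same call is just the stdlib (parameter values below are made up):

```python
# six.moves.urllib.parse maps to the stdlib urllib.parse on Python 3.
from urllib.parse import urlencode

params = {"app_id": "com.example.app", "pushkey": "abc"}
qs = urlencode(params)
assert qs == "app_id=com.example.app&pushkey=abc"
```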


@@ -40,9 +40,10 @@ REQUIREMENTS = {
    "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
    "service_identity>=1.0.0": ["service_identity>=1.0.0"],
    "Twisted>=17.1.0": ["twisted>=17.1.0"],
+    "treq>=15.1": ["treq>=15.1"],

-    # We use crypto.get_elliptic_curve which is only supported in >=0.15
-    "pyopenssl>=0.15": ["OpenSSL>=0.15"],
+    # Twisted has required pyopenssl 16.0 since about Twisted 16.6.
+    "pyopenssl>=16.0.0": ["OpenSSL>=16.0.0"],

    "pyyaml": ["yaml"],
    "pyasn1": ["pyasn1"],


@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import six
+
from synapse.storage import DataStore
from synapse.storage.end_to_end_keys import EndToEndKeyStore
from synapse.util.caches.stream_change_cache import StreamChangeCache

@@ -21,6 +23,13 @@ from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker

+def __func__(inp):
+    if six.PY3:
+        return inp
+    else:
+        return inp.__func__
+
class SlavedDeviceStore(BaseSlavedStore):
    def __init__(self, db_conn, hs):
        super(SlavedDeviceStore, self).__init__(db_conn, hs)

@@ -38,14 +47,14 @@ class SlavedDeviceStore(BaseSlavedStore):
        "DeviceListFederationStreamChangeCache", device_list_max,
    )

-    get_device_stream_token = DataStore.get_device_stream_token.__func__
-    get_user_whose_devices_changed = DataStore.get_user_whose_devices_changed.__func__
-    get_devices_by_remote = DataStore.get_devices_by_remote.__func__
-    _get_devices_by_remote_txn = DataStore._get_devices_by_remote_txn.__func__
-    _get_e2e_device_keys_txn = DataStore._get_e2e_device_keys_txn.__func__
-    mark_as_sent_devices_by_remote = DataStore.mark_as_sent_devices_by_remote.__func__
-    _mark_as_sent_devices_by_remote_txn = (
-        DataStore._mark_as_sent_devices_by_remote_txn.__func__
-    )
+    get_device_stream_token = __func__(DataStore.get_device_stream_token)
+    get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed)
+    get_devices_by_remote = __func__(DataStore.get_devices_by_remote)
+    _get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn)
+    _get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn)
+    mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote)
+    _mark_as_sent_devices_by_remote_txn = (
+        __func__(DataStore._mark_as_sent_devices_by_remote_txn)
+    )
    count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]
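The `__func__` shim exists because `SomeClass.method` is an unbound-method wrapper on Python 2 (needing `.__func__` to borrow the underlying function into another class) but already a plain function on Python 3. A self-contained sketch of the borrowing trick (classes here are stand-ins, and this variant uses `getattr` instead of the `six.PY3` check, which is equivalent):

```python
class DataStoreSketch(object):
    def get_device_stream_token(self):
        return 42

def unwrap(inp):
    # Python 3: class attribute access yields a plain function (no
    # __func__ attribute); Python 2 yielded an unbound method.
    return getattr(inp, "__func__", inp)

class SlavedStoreSketch(object):
    # Borrow the implementation, exactly as the slaved stores do.
    get_device_stream_token = unwrap(DataStoreSketch.get_device_stream_token)

assert SlavedStoreSketch().get_device_stream_token() == 42
```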


@@ -590,9 +590,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
pending_commands = LaterGauge(
    "synapse_replication_tcp_protocol_pending_commands",
    "",
-    ["name", "conn_id"],
+    ["name"],
    lambda: {
-        (p.name, p.conn_id): len(p.pending_commands) for p in connected_connections
+        (p.name,): len(p.pending_commands) for p in connected_connections
    },
)

@@ -607,9 +607,9 @@ def transport_buffer_size(protocol):
transport_send_buffer = LaterGauge(
    "synapse_replication_tcp_protocol_transport_send_buffer",
    "",
-    ["name", "conn_id"],
+    ["name"],
    lambda: {
-        (p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections
+        (p.name,): transport_buffer_size(p) for p in connected_connections
    },
)

@@ -632,9 +632,9 @@ def transport_kernel_read_buffer_size(protocol, read=True):
tcp_transport_kernel_send_buffer = LaterGauge(
    "synapse_replication_tcp_protocol_transport_kernel_send_buffer",
    "",
-    ["name", "conn_id"],
+    ["name"],
    lambda: {
-        (p.name, p.conn_id): transport_kernel_read_buffer_size(p, False)
+        (p.name,): transport_kernel_read_buffer_size(p, False)
        for p in connected_connections
    },
)

@@ -643,9 +643,9 @@ tcp_transport_kernel_send_buffer = LaterGauge(
tcp_transport_kernel_read_buffer = LaterGauge(
    "synapse_replication_tcp_protocol_transport_kernel_read_buffer",
    "",
-    ["name", "conn_id"],
+    ["name"],
    lambda: {
-        (p.name, p.conn_id): transport_kernel_read_buffer_size(p, True)
+        (p.name,): transport_kernel_read_buffer_size(p, True)
        for p in connected_connections
    },
)

@@ -654,9 +654,9 @@ tcp_transport_kernel_read_buffer = LaterGauge(
tcp_inbound_commands = LaterGauge(
    "synapse_replication_tcp_protocol_inbound_commands",
    "",
-    ["command", "name", "conn_id"],
+    ["command", "name"],
    lambda: {
-        (k[0], p.name, p.conn_id): count
+        (k[0], p.name,): count
        for p in connected_connections
        for k, count in iteritems(p.inbound_commands_counter)
    },

@@ -665,9 +665,9 @@ tcp_inbound_commands = LaterGauge(
tcp_outbound_commands = LaterGauge(
    "synapse_replication_tcp_protocol_outbound_commands",
    "",
-    ["command", "name", "conn_id"],
+    ["command", "name"],
    lambda: {
-        (k[0], p.name, p.conn_id): count
+        (k[0], p.name,): count
        for p in connected_connections
        for k, count in iteritems(p.outbound_commands_counter)
    },


@@ -196,7 +196,7 @@ class Stream(object):
            )

            if len(rows) >= MAX_EVENTS_BEHIND:
-                raise Exception("stream %s has fallen behined" % (self.NAME))
+                raise Exception("stream %s has fallen behind" % (self.NAME))
        else:
            rows = yield self.update_function(
                from_token, current_token,


@@ -101,7 +101,7 @@ class UserRegisterServlet(ClientV1RestServlet):
        nonce = self.hs.get_secrets().token_hex(64)
        self.nonces[nonce] = int(self.reactor.seconds())
-        return (200, {"nonce": nonce.encode('ascii')})
+        return (200, {"nonce": nonce})

    @defer.inlineCallbacks
    def on_POST(self, request):

@@ -164,7 +164,7 @@ class UserRegisterServlet(ClientV1RestServlet):
            key=self.hs.config.registration_shared_secret.encode(),
            digestmod=hashlib.sha1,
        )
-        want_mac.update(nonce)
+        want_mac.update(nonce.encode('utf8'))
        want_mac.update(b"\x00")
        want_mac.update(username)
        want_mac.update(b"\x00")

@@ -173,7 +173,10 @@ class UserRegisterServlet(ClientV1RestServlet):
        want_mac.update(b"admin" if admin else b"notadmin")
        want_mac = want_mac.hexdigest()

-        if not hmac.compare_digest(want_mac, got_mac.encode('ascii')):
+        if not hmac.compare_digest(
+            want_mac.encode('ascii'),
+            got_mac.encode('ascii')
+        ):
            raise SynapseError(403, "HMAC incorrect")

        # Reuse the parts of RegisterRestServlet to reduce code duplication
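On Python 3, `hmac.update()` only accepts bytes and `hmac.compare_digest()` requires both arguments to be the same type, hence the explicit `.encode()` calls above. A sketch of the shared-secret registration MAC (the secret and field values below are made up; the field order mirrors the servlet):

```python
import hashlib
import hmac

secret = b"registration_shared_secret"   # illustrative, not a real secret

def expected_mac(nonce, username, password, admin):
    mac = hmac.new(key=secret, digestmod=hashlib.sha1)
    mac.update(nonce.encode('utf8'))     # nonce is str; HMAC wants bytes
    mac.update(b"\x00")
    mac.update(username)
    mac.update(b"\x00")
    mac.update(password)
    mac.update(b"\x00")
    mac.update(b"admin" if admin else b"notadmin")
    return mac.hexdigest()

want_mac = expected_mac("abc123", b"alice", b"hunter2", False)
got_mac = want_mac   # pretend the client computed the same value

# compare_digest needs both sides str or both bytes on Python 3.
assert hmac.compare_digest(want_mac.encode('ascii'), got_mac.encode('ascii'))
```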


@@ -45,20 +45,20 @@ class EventStreamRestServlet(ClientV1RestServlet):
        is_guest = requester.is_guest
        room_id = None
        if is_guest:
-            if "room_id" not in request.args:
+            if b"room_id" not in request.args:
                raise SynapseError(400, "Guest users must specify room_id param")
-        if "room_id" in request.args:
-            room_id = request.args["room_id"][0]
+        if b"room_id" in request.args:
+            room_id = request.args[b"room_id"][0].decode('ascii')

        pagin_config = PaginationConfig.from_request(request)
        timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
-        if "timeout" in request.args:
+        if b"timeout" in request.args:
            try:
-                timeout = int(request.args["timeout"][0])
+                timeout = int(request.args[b"timeout"][0])
            except ValueError:
                raise SynapseError(400, "timeout must be in milliseconds.")

-        as_client_event = "raw" not in request.args
+        as_client_event = b"raw" not in request.args
        chunk = yield self.event_stream_handler.get_stream(
            requester.user.to_string(),
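Under Python 3, Twisted parses the query string into a dict keyed by bytes, with lists of bytes values, so every `request.args` lookup needs a `b"..."` key and usually a decode. An illustrative literal (not a real request):

```python
# Shape of Twisted's request.args on Python 3: bytes keys, bytes values.
args = {b"room_id": [b"!abc:example.org"], b"timeout": [b"30000"]}

room_id = args[b"room_id"][0].decode('ascii')
timeout = int(args[b"timeout"][0])   # int() accepts ASCII-digit bytes

assert room_id == "!abc:example.org"
assert timeout == 30000
```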


@@ -32,7 +32,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
    @defer.inlineCallbacks
    def on_GET(self, request):
        requester = yield self.auth.get_user_by_req(request)
-        as_client_event = "raw" not in request.args
+        as_client_event = b"raw" not in request.args
        pagination_config = PaginationConfig.from_request(request)
        include_archived = parse_boolean(request, "archived", default=False)
        content = yield self.initial_sync_handler.snapshot_all_rooms(


@@ -14,10 +14,9 @@
# limitations under the License.

import logging
-import urllib
import xml.etree.ElementTree as ET

-from six.moves.urllib import parse as urlparse
+from six.moves import urllib

from canonicaljson import json
from saml2 import BINDING_HTTP_POST, config

@@ -134,7 +133,7 @@ class LoginRestServlet(ClientV1RestServlet):
                      LoginRestServlet.SAML2_TYPE):
            relay_state = ""
            if "relay_state" in login_submission:
-                relay_state = "&RelayState=" + urllib.quote(
+                relay_state = "&RelayState=" + urllib.parse.quote(
                    login_submission["relay_state"])
            result = {
                "uri": "%s%s" % (self.idp_redirect_url, relay_state)

@@ -366,7 +365,7 @@ class SAML2RestServlet(ClientV1RestServlet):
            (user_id, token) = yield handler.register_saml2(username)
            # Forward to the RelayState callback along with ava
            if 'RelayState' in request.args:
-                request.redirect(urllib.unquote(
+                request.redirect(urllib.parse.unquote(
                    request.args['RelayState'][0]) +
                    '?status=authenticated&access_token=' +
                    token + '&user_id=' + user_id + '&ava=' +

@@ -377,7 +376,7 @@ class SAML2RestServlet(ClientV1RestServlet):
                               "user_id": user_id, "token": token,
                               "ava": saml2_auth.ava}))
        elif 'RelayState' in request.args:
-            request.redirect(urllib.unquote(
+            request.redirect(urllib.parse.unquote(
                request.args['RelayState'][0]) +
                '?status=not_authenticated')
            finish_request(request)

@@ -390,21 +389,22 @@ class CasRedirectServlet(ClientV1RestServlet):
    def __init__(self, hs):
        super(CasRedirectServlet, self).__init__(hs)
-        self.cas_server_url = hs.config.cas_server_url
-        self.cas_service_url = hs.config.cas_service_url
+        self.cas_server_url = hs.config.cas_server_url.encode('ascii')
+        self.cas_service_url = hs.config.cas_service_url.encode('ascii')

    def on_GET(self, request):
        args = request.args
-        if "redirectUrl" not in args:
+        if b"redirectUrl" not in args:
            return (400, "Redirect URL not specified for CAS auth")
-        client_redirect_url_param = urllib.urlencode({
-            "redirectUrl": args["redirectUrl"][0]
-        })
-        hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket"
-        service_param = urllib.urlencode({
-            "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
-        })
-        request.redirect("%s/login?%s" % (self.cas_server_url, service_param))
+        client_redirect_url_param = urllib.parse.urlencode({
+            b"redirectUrl": args[b"redirectUrl"][0]
+        }).encode('ascii')
+        hs_redirect_url = (self.cas_service_url +
+                           b"/_matrix/client/api/v1/login/cas/ticket")
+        service_param = urllib.parse.urlencode({
+            b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
+        }).encode('ascii')
+        request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param))
        finish_request(request)

@@ -422,11 +422,11 @@ class CasTicketServlet(ClientV1RestServlet):
    @defer.inlineCallbacks
    def on_GET(self, request):
-        client_redirect_url = request.args["redirectUrl"][0]
+        client_redirect_url = request.args[b"redirectUrl"][0]
        http_client = self.hs.get_simple_http_client()
        uri = self.cas_server_url + "/proxyValidate"
        args = {
-            "ticket": request.args["ticket"],
+            "ticket": request.args[b"ticket"][0].decode('ascii'),
            "service": self.cas_service_url
        }
        try:

@@ -471,11 +471,11 @@ class CasTicketServlet(ClientV1RestServlet):
            finish_request(request)

    def add_login_token_to_redirect_url(self, url, token):
-        url_parts = list(urlparse.urlparse(url))
-        query = dict(urlparse.parse_qsl(url_parts[4]))
+        url_parts = list(urllib.parse.urlparse(url))
+        query = dict(urllib.parse.parse_qsl(url_parts[4]))
        query.update({"loginToken": token})
-        url_parts[4] = urllib.urlencode(query)
-        return urlparse.urlunparse(url_parts)
+        url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
+        return urllib.parse.urlunparse(url_parts)

    def parse_cas_response(self, cas_response_body):
        user = None
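`add_login_token_to_redirect_url` is a standard parse/modify/unparse round trip over `urllib.parse`. A standalone sketch of the same idea (the function name `add_query_param` is a stand-in, and this version skips the final `.encode('ascii')` so it stays all-str):

```python
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse

def add_query_param(url, key, value):
    parts = list(urlparse(url))          # [scheme, netloc, path, params, query, fragment]
    query = dict(parse_qsl(parts[4]))    # parts[4] is the query string
    query[key] = value
    parts[4] = urlencode(query)
    return urlunparse(parts)

result = add_query_param("https://x/cb?a=1", "loginToken", "t")
assert result == "https://x/cb?a=1&loginToken=t"
```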


@@ -46,7 +46,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
        try:
            priority_class = _priority_class_from_spec(spec)
        except InvalidRuleException as e:
-            raise SynapseError(400, e.message)
+            raise SynapseError(400, str(e))

        requester = yield self.auth.get_user_by_req(request)

@@ -73,7 +73,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
                content,
            )
        except InvalidRuleException as e:
-            raise SynapseError(400, e.message)
+            raise SynapseError(400, str(e))

        before = parse_string(request, "before")
        if before:

@@ -95,9 +95,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
            )
            self.notify_user(user_id)
        except InconsistentRuleException as e:
-            raise SynapseError(400, e.message)
+            raise SynapseError(400, str(e))
        except RuleNotFoundException as e:
-            raise SynapseError(400, e.message)
+            raise SynapseError(400, str(e))

        defer.returnValue((200, {}))

@@ -142,10 +142,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
                PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
            )

-        if path[0] == '':
+        if path[0] == b'':
            defer.returnValue((200, rules))
-        elif path[0] == 'global':
-            path = path[1:]
+        elif path[0] == b'global':
+            path = [x.decode('ascii') for x in path[1:]]
            result = _filter_ruleset_with_path(rules['global'], path)
            defer.returnValue((200, result))
        else:

@@ -192,10 +192,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
def _rule_spec_from_path(path):
    if len(path) < 2:
        raise UnrecognizedRequestError()
-    if path[0] != 'pushrules':
+    if path[0] != b'pushrules':
        raise UnrecognizedRequestError()

-    scope = path[1]
+    scope = path[1].decode('ascii')
    path = path[2:]

    if scope != 'global':
        raise UnrecognizedRequestError()

@@ -203,13 +203,13 @@ def _rule_spec_from_path(path):
    if len(path) == 0:
        raise UnrecognizedRequestError()

-    template = path[0]
+    template = path[0].decode('ascii')
    path = path[1:]

    if len(path) == 0 or len(path[0]) == 0:
        raise UnrecognizedRequestError()

-    rule_id = path[0]
+    rule_id = path[0].decode('ascii')

    spec = {
        'scope': scope,

@@ -220,7 +220,7 @@ def _rule_spec_from_path(path):
        path = path[1:]

    if len(path) > 0 and len(path[0]) > 0:
-        spec['attr'] = path[0]
+        spec['attr'] = path[0].decode('ascii')

    return spec
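The `e.message` → `str(e)` changes are needed because the `message` attribute on exceptions was a Python 2-only convenience (deprecated since 2.6, gone in 3). `str(e)` works on both. A minimal check:

```python
class InvalidRuleException(Exception):
    """Stand-in for the servlet's exception type."""

try:
    raise InvalidRuleException("no scope given")
except InvalidRuleException as e:
    # Python 3 exceptions carry no .message attribute; str(e) is portable.
    assert not hasattr(e, "message")
    assert str(e) == "no scope given"
```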


@@ -59,7 +59,7 @@ class PushersRestServlet(ClientV1RestServlet):
        ]

        for p in pushers:
-            for k, v in p.items():
+            for k, v in list(p.items()):
                if k not in allowed_keys:
                    del p[k]

@@ -126,7 +126,7 @@ class PushersSetRestServlet(ClientV1RestServlet):
                profile_tag=content.get('profile_tag', ""),
            )
        except PusherConfigException as pce:
-            raise SynapseError(400, "Config Error: " + str(pce),
+            raise SynapseError(400, "Config Error: " + str(pce),
                               errcode=Codes.MISSING_PARAM)

        self.notifier.on_new_replication_data()
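The `list(p.items())` wrapper matters because Python 2's `items()` returned a list snapshot, while Python 3 returns a live view that raises `RuntimeError` if the dict is resized mid-iteration. A demonstration with made-up pusher fields:

```python
allowed_keys = {"app_id", "kind", "pushkey"}
p = {"app_id": "a", "kind": "http", "pushkey": "k", "secret": "s3kr1t"}

# Iterate over a snapshot; deleting from the live view would raise
# RuntimeError("dictionary changed size during iteration") on Python 3.
for k, v in list(p.items()):
    if k not in allowed_keys:
        del p[k]

assert p == {"app_id": "a", "kind": "http", "pushkey": "k"}
```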


@@ -207,7 +207,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
            "sender": requester.user.to_string(),
        }

-        if 'ts' in request.args and requester.app_service:
+        if b'ts' in request.args and requester.app_service:
            event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)

        event = yield self.event_creation_hander.create_and_send_nonmember_event(

@@ -255,7 +255,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
        if RoomID.is_valid(room_identifier):
            room_id = room_identifier
            try:
-                remote_room_hosts = request.args["server_name"]
+                remote_room_hosts = [
+                    x.decode('ascii') for x in request.args[b"server_name"]
+                ]
            except Exception:
                remote_room_hosts = None
        elif RoomAlias.is_valid(room_identifier):

@@ -461,10 +463,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
        pagination_config = PaginationConfig.from_request(
            request, default_limit=10,
        )
-        as_client_event = "raw" not in request.args
-        filter_bytes = parse_string(request, "filter")
+        as_client_event = b"raw" not in request.args
+        filter_bytes = parse_string(request, b"filter", encoding=None)
        if filter_bytes:
-            filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
+            filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
            event_filter = Filter(json.loads(filter_json))
        else:
            event_filter = None

@@ -560,7 +562,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
        # picking the API shape for symmetry with /messages
        filter_bytes = parse_string(request, "filter")
        if filter_bytes:
-            filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
+            filter_json = urlparse.unquote(filter_bytes)
            event_filter = Filter(json.loads(filter_json))
        else:
            event_filter = None


@@ -42,7 +42,11 @@ class VoipRestServlet(ClientV1RestServlet):
            expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
            username = "%d:%s" % (expiry, requester.user.to_string())

-            mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
+            mac = hmac.new(
+                turnSecret.encode(),
+                msg=username.encode(),
+                digestmod=hashlib.sha1
+            )
            # We need to use standard padded base64 encoding here
            # encode_base64 because we need to add the standard padding to get the
            # same result as the TURN server.
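This was the Python 3 VoIP crash (#3821): `hmac.new` rejects str keys and messages, so both sides are now encoded. The TURN REST credential is HMAC-SHA1 over `expiry:user_id`, base64-encoded with standard padding. A sketch with made-up values:

```python
import base64
import hashlib
import hmac

turn_secret = "turnsecret"                  # illustrative shared secret
username = "1537000000:@alice:example.org"  # "expiry:user_id"

mac = hmac.new(
    turn_secret.encode(),     # bytes key, required on Python 3
    msg=username.encode(),    # bytes message, likewise
    digestmod=hashlib.sha1,
)
# Standard padded base64, matching what the TURN server computes.
password = base64.b64encode(mac.digest())

assert len(mac.digest()) == 20   # SHA-1 digest size in bytes
assert len(password) == 28       # 20 bytes -> 28 base64 characters
```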


@@ -53,7 +53,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
        if not check_3pid_allowed(self.hs, "email", body['email']):
            raise SynapseError(
-                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+                403,
+                "Your email domain is not authorized on this server",
+                Codes.THREEPID_DENIED,
            )

        existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(

@@ -89,7 +91,9 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
        if not check_3pid_allowed(self.hs, "msisdn", msisdn):
            raise SynapseError(
-                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+                403,
+                "Account phone numbers are not authorized on this server",
+                Codes.THREEPID_DENIED,
            )

        existingUid = yield self.datastore.get_user_id_by_threepid(

@@ -241,7 +245,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
        if not check_3pid_allowed(self.hs, "email", body['email']):
            raise SynapseError(
-                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+                403,
+                "Your email domain is not authorized on this server",
+                Codes.THREEPID_DENIED,
            )

        existingUid = yield self.datastore.get_user_id_by_threepid(

@@ -276,7 +282,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
        if not check_3pid_allowed(self.hs, "msisdn", msisdn):
            raise SynapseError(
-                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+                403,
+                "Account phone numbers are not authorized on this server",
+                Codes.THREEPID_DENIED,
            )

        existingUid = yield self.datastore.get_user_id_by_threepid(


@@ -75,7 +75,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
        if not check_3pid_allowed(self.hs, "email", body['email']):
            raise SynapseError(
-                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+                403,
+                "Your email domain is not authorized to register on this server",
+                Codes.THREEPID_DENIED,
            )

        existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(

@@ -115,7 +117,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
        if not check_3pid_allowed(self.hs, "msisdn", msisdn):
            raise SynapseError(
-                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+                403,
+                "Phone numbers are not authorized to register on this server",
+                Codes.THREEPID_DENIED,
            )

        existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(

@@ -373,7 +377,9 @@ class RegisterRestServlet(RestServlet):
            if not check_3pid_allowed(self.hs, medium, address):
                raise SynapseError(
-                    403, "Third party identifier is not allowed",
+                    403,
+                    "Third party identifiers (email/phone numbers)" +
+                    " are not authorized on this server",
                    Codes.THREEPID_DENIED,
                )


@@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
 from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
 from synapse.events.utils import (
     format_event_for_client_v2_without_room_id,
+    format_event_raw,
     serialize_event,
 )
 from synapse.handlers.presence import format_user_presence_state
@@ -88,7 +89,7 @@ class SyncRestServlet(RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request):
-        if "from" in request.args:
+        if b"from" in request.args:
            # /events used to use 'from', but /sync uses 'since'.
            # Lets be helpful and whine if we see a 'from'.
            raise SynapseError(
@@ -175,17 +176,28 @@ class SyncRestServlet(RestServlet):
     @staticmethod
     def encode_response(time_now, sync_result, access_token_id, filter):
+        if filter.event_format == 'client':
+            event_formatter = format_event_for_client_v2_without_room_id
+        elif filter.event_format == 'federation':
+            event_formatter = format_event_raw
+        else:
+            raise Exception("Unknown event format %s" % (filter.event_format, ))
+
         joined = SyncRestServlet.encode_joined(
-            sync_result.joined, time_now, access_token_id, filter.event_fields
+            sync_result.joined, time_now, access_token_id,
+            filter.event_fields,
+            event_formatter,
         )

         invited = SyncRestServlet.encode_invited(
             sync_result.invited, time_now, access_token_id,
+            event_formatter,
         )

         archived = SyncRestServlet.encode_archived(
             sync_result.archived, time_now, access_token_id,
             filter.event_fields,
+            event_formatter,
         )

         return {
@@ -228,7 +240,7 @@ class SyncRestServlet(RestServlet):
         }

     @staticmethod
-    def encode_joined(rooms, time_now, token_id, event_fields):
+    def encode_joined(rooms, time_now, token_id, event_fields, event_formatter):
         """
         Encode the joined rooms in a sync result
@@ -240,7 +252,9 @@ class SyncRestServlet(RestServlet):
             token_id(int): ID of the user's auth token - used for namespacing
                 of transaction IDs
             event_fields(list<str>): List of event fields to include. If empty,
                 all fields will be returned.
+            event_formatter (func[dict]): function to convert from federation format
+                to client format
         Returns:
             dict[str, dict[str, object]]: the joined rooms list, in our
                 response format
@@ -248,13 +262,14 @@ class SyncRestServlet(RestServlet):
         joined = {}
         for room in rooms:
             joined[room.room_id] = SyncRestServlet.encode_room(
-                room, time_now, token_id, only_fields=event_fields
+                room, time_now, token_id, joined=True, only_fields=event_fields,
+                event_formatter=event_formatter,
             )

         return joined

     @staticmethod
-    def encode_invited(rooms, time_now, token_id):
+    def encode_invited(rooms, time_now, token_id, event_formatter):
         """
         Encode the invited rooms in a sync result
@@ -264,7 +279,9 @@ class SyncRestServlet(RestServlet):
             time_now(int): current time - used as a baseline for age
                 calculations
             token_id(int): ID of the user's auth token - used for namespacing
                 of transaction IDs
+            event_formatter (func[dict]): function to convert from federation format
+                to client format

         Returns:
             dict[str, dict[str, object]]: the invited rooms list, in our
@@ -274,7 +291,7 @@ class SyncRestServlet(RestServlet):
         for room in rooms:
             invite = serialize_event(
                 room.invite, time_now, token_id=token_id,
-                event_format=format_event_for_client_v2_without_room_id,
+                event_format=event_formatter,
                 is_invite=True,
             )
             unsigned = dict(invite.get("unsigned", {}))
@@ -288,7 +305,7 @@ class SyncRestServlet(RestServlet):
         return invited

     @staticmethod
-    def encode_archived(rooms, time_now, token_id, event_fields):
+    def encode_archived(rooms, time_now, token_id, event_fields, event_formatter):
         """
         Encode the archived rooms in a sync result
@@ -300,7 +317,9 @@ class SyncRestServlet(RestServlet):
             token_id(int): ID of the user's auth token - used for namespacing
                 of transaction IDs
             event_fields(list<str>): List of event fields to include. If empty,
                 all fields will be returned.
+            event_formatter (func[dict]): function to convert from federation format
+                to client format
         Returns:
             dict[str, dict[str, object]]: The invited rooms list, in our
                 response format
@@ -308,13 +327,18 @@ class SyncRestServlet(RestServlet):
         joined = {}
         for room in rooms:
             joined[room.room_id] = SyncRestServlet.encode_room(
-                room, time_now, token_id, joined=False, only_fields=event_fields
+                room, time_now, token_id, joined=False,
+                only_fields=event_fields,
+                event_formatter=event_formatter,
             )

         return joined

     @staticmethod
-    def encode_room(room, time_now, token_id, joined=True, only_fields=None):
+    def encode_room(
+            room, time_now, token_id, joined,
+            only_fields, event_formatter,
+    ):
         """
         Args:
             room (JoinedSyncResult|ArchivedSyncResult): sync result for a
@@ -326,14 +350,15 @@ class SyncRestServlet(RestServlet):
             joined (bool): True if the user is joined to this room - will mean
                 we handle ephemeral events
             only_fields(list<str>): Optional. The list of event fields to include.
+            event_formatter (func[dict]): function to convert from federation format
+                to client format
         Returns:
             dict[str, object]: the room, encoded in our response format
         """
         def serialize(event):
-            # TODO(mjark): Respect formatting requirements in the filter.
             return serialize_event(
                 event, time_now, token_id=token_id,
-                event_format=format_event_for_client_v2_without_room_id,
+                event_format=event_formatter,
                 only_event_fields=only_fields,
             )
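The new dispatch at the top of `encode_response` can be sketched as a standalone function; the two formatter bodies below are simplified stand-ins for the real helpers in `synapse.events.utils`, kept only to make the dispatch runnable:

```python
def format_event_raw(event):
    # Federation format: pass the event dict through untouched.
    return event


def format_event_for_client_v2_without_room_id(event):
    # Client format (simplified stand-in): drop the room_id key.
    ev = dict(event)
    ev.pop("room_id", None)
    return ev


def select_event_formatter(event_format):
    # Mirrors the dispatch added to SyncRestServlet.encode_response.
    if event_format == 'client':
        return format_event_for_client_v2_without_room_id
    elif event_format == 'federation':
        return format_event_raw
    else:
        raise Exception("Unknown event format %s" % (event_format,))
```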


@@ -79,7 +79,7 @@ class ThirdPartyUserServlet(RestServlet):
         yield self.auth.get_user_by_req(request, allow_guest=True)

         fields = request.args
-        fields.pop("access_token", None)
+        fields.pop(b"access_token", None)

         results = yield self.appservice_handler.query_3pe(
             ThirdPartyEntityKind.USER, protocol, fields
@@ -102,7 +102,7 @@ class ThirdPartyLocationServlet(RestServlet):
         yield self.auth.get_user_by_req(request, allow_guest=True)

         fields = request.args
-        fields.pop("access_token", None)
+        fields.pop(b"access_token", None)

         results = yield self.appservice_handler.query_3pe(
             ThirdPartyEntityKind.LOCATION, protocol, fields


@@ -88,5 +88,5 @@ class LocalKey(Resource):
         )

     def getChild(self, name, request):
-        if name == '':
+        if name == b'':
             return self


@@ -22,5 +22,5 @@ from .remote_key_resource import RemoteKey
 class KeyApiV2Resource(Resource):
     def __init__(self, hs):
         Resource.__init__(self)
-        self.putChild("server", LocalKey(hs))
-        self.putChild("query", RemoteKey(hs))
+        self.putChild(b"server", LocalKey(hs))
+        self.putChild(b"query", RemoteKey(hs))


@@ -103,7 +103,7 @@ class RemoteKey(Resource):
     def async_render_GET(self, request):
         if len(request.postpath) == 1:
             server, = request.postpath
-            query = {server: {}}
+            query = {server.decode('ascii'): {}}

         elif len(request.postpath) == 2:
             server, key_id = request.postpath
             minimum_valid_until_ts = parse_integer(
@@ -112,11 +112,12 @@ class RemoteKey(Resource):
             arguments = {}
             if minimum_valid_until_ts is not None:
                 arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
-            query = {server: {key_id: arguments}}
+            query = {server.decode('ascii'): {key_id.decode('ascii'): arguments}}
         else:
             raise SynapseError(
                 404, "Not found %r" % request.postpath, Codes.NOT_FOUND
             )

         yield self.query_keys(request, query, query_remote_on_cache_miss=True)

     def render_POST(self, request):
@@ -135,6 +136,7 @@ class RemoteKey(Resource):
     @defer.inlineCallbacks
     def query_keys(self, request, query, query_remote_on_cache_miss=False):
         logger.info("Handling query for keys %r", query)
+
         store_queries = []
         for server_name, key_ids in query.items():
             if (


@@ -56,7 +56,7 @@ class ContentRepoResource(resource.Resource):
             # servers.

             # TODO: A little crude here, we could do this better.
-            filename = request.path.split('/')[-1]
+            filename = request.path.decode('ascii').split('/')[-1]
             # be paranoid
             filename = re.sub("[^0-9A-z.-_]", "", filename)
@@ -78,7 +78,7 @@ class ContentRepoResource(resource.Resource):
             # select private. don't bother setting Expires as all our matrix
            # clients are smart enough to be happy with Cache-Control (right?)
            request.setHeader(
-                "Cache-Control", "public,max-age=86400,s-maxage=86400"
+                b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
            )

            d = FileSender().beginFileTransfer(f, request)


@@ -15,9 +15,8 @@
 import logging
 import os
-import urllib

-from six.moves.urllib import parse as urlparse
+from six.moves import urllib

 from twisted.internet import defer
 from twisted.protocols.basic import FileSender
@@ -35,10 +34,15 @@ def parse_media_id(request):
         # This allows users to append e.g. /test.png to the URL. Useful for
         # clients that parse the URL to see content type.
         server_name, media_id = request.postpath[:2]
+
+        if isinstance(server_name, bytes):
+            server_name = server_name.decode('utf-8')
+            media_id = media_id.decode('utf8')
+
         file_name = None
         if len(request.postpath) > 2:
             try:
-                file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8")
+                file_name = urllib.parse.unquote(request.postpath[-1].decode("utf-8"))
             except UnicodeDecodeError:
                 pass
         return server_name, media_id, file_name
@@ -93,22 +97,18 @@ def add_file_headers(request, media_type, file_size, upload_name):
         file_size (int): Size in bytes of the media, if known.
         upload_name (str): The name of the requested file, if any.
     """
+    def _quote(x):
+        return urllib.parse.quote(x.encode("utf-8"))
+
     request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
     if upload_name:
         if is_ascii(upload_name):
-            request.setHeader(
-                b"Content-Disposition",
-                b"inline; filename=%s" % (
-                    urllib.quote(upload_name.encode("utf-8")),
-                ),
-            )
+            disposition = ("inline; filename=%s" % (_quote(upload_name),)).encode("ascii")
         else:
-            request.setHeader(
-                b"Content-Disposition",
-                b"inline; filename*=utf-8''%s" % (
-                    urllib.quote(upload_name.encode("utf-8")),
-                ),
-            )
+            disposition = (
+                "inline; filename*=utf-8''%s" % (_quote(upload_name),)).encode("ascii")
+
+        request.setHeader(b"Content-Disposition", disposition)

     # cache for at least a day.
     # XXX: we might want to turn this off for data we don't want to
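The refactor above folds both branches into a single `disposition` bytestring. A self-contained sketch of that encoding rule (plain `filename=` for ASCII names, RFC 5987 `filename*=utf-8''` otherwise; the function name here is illustrative):

```python
from urllib.parse import quote


def content_disposition_for(upload_name):
    # Same shape as the diff's add_file_headers: percent-encode the name,
    # and pick filename= vs filename*= based on whether it is pure ASCII.
    def _quote(x):
        return quote(x.encode("utf-8"))

    try:
        upload_name.encode("ascii")
        name_is_ascii = True
    except UnicodeEncodeError:
        name_is_ascii = False

    if name_is_ascii:
        return ("inline; filename=%s" % (_quote(upload_name),)).encode("ascii")
    return ("inline; filename*=utf-8''%s" % (_quote(upload_name),)).encode("ascii")
```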


@@ -47,12 +47,12 @@ class DownloadResource(Resource):
     def _async_render_GET(self, request):
         set_cors_headers(request)
         request.setHeader(
-            "Content-Security-Policy",
-            "default-src 'none';"
-            " script-src 'none';"
-            " plugin-types application/pdf;"
-            " style-src 'unsafe-inline';"
-            " object-src 'self';"
+            b"Content-Security-Policy",
+            b"default-src 'none';"
+            b" script-src 'none';"
+            b" plugin-types application/pdf;"
+            b" style-src 'unsafe-inline';"
+            b" object-src 'self';"
         )
         server_name, media_id, name = parse_media_id(request)
         if server_name == self.server_name:


@@ -20,7 +20,7 @@ import logging
 import os
 import shutil

-from six import iteritems
+from six import PY3, iteritems
 from six.moves.urllib import parse as urlparse

 import twisted.internet.error
@@ -397,13 +397,13 @@ class MediaRepository(object):

             yield finish()

-            media_type = headers["Content-Type"][0]
+            media_type = headers[b"Content-Type"][0].decode('ascii')
             time_now_ms = self.clock.time_msec()

-            content_disposition = headers.get("Content-Disposition", None)
+            content_disposition = headers.get(b"Content-Disposition", None)
             if content_disposition:
-                _, params = cgi.parse_header(content_disposition[0],)
+                _, params = cgi.parse_header(content_disposition[0].decode('ascii'),)
                 upload_name = None

                 # First check if there is a valid UTF-8 filename
@@ -419,9 +419,13 @@ class MediaRepository(object):
                         upload_name = upload_name_ascii

                 if upload_name:
-                    upload_name = urlparse.unquote(upload_name)
+                    if PY3:
+                        upload_name = urlparse.unquote(upload_name)
+                    else:
+                        upload_name = urlparse.unquote(upload_name.encode('ascii'))
                     try:
-                        upload_name = upload_name.decode("utf-8")
+                        if isinstance(upload_name, bytes):
+                            upload_name = upload_name.decode("utf-8")
                     except UnicodeDecodeError:
                         upload_name = None
             else:
@@ -755,14 +759,15 @@ class MediaRepositoryResource(Resource):
         Resource.__init__(self)

         media_repo = hs.get_media_repository()
-        self.putChild("upload", UploadResource(hs, media_repo))
-        self.putChild("download", DownloadResource(hs, media_repo))
-        self.putChild("thumbnail", ThumbnailResource(
+
+        self.putChild(b"upload", UploadResource(hs, media_repo))
+        self.putChild(b"download", DownloadResource(hs, media_repo))
+        self.putChild(b"thumbnail", ThumbnailResource(
             hs, media_repo, media_repo.media_storage,
         ))
-        self.putChild("identicon", IdenticonResource())
+        self.putChild(b"identicon", IdenticonResource())
         if hs.config.url_preview_enabled:
-            self.putChild("preview_url", PreviewUrlResource(
+            self.putChild(b"preview_url", PreviewUrlResource(
                 hs, media_repo, media_repo.media_storage,
             ))
-        self.putChild("config", MediaConfigResource(hs))
+        self.putChild(b"config", MediaConfigResource(hs))


@@ -261,7 +261,7 @@ class PreviewUrlResource(Resource):

         logger.debug("Calculated OG for %s as %s" % (url, og))

-        jsonog = json.dumps(og)
+        jsonog = json.dumps(og).encode('utf8')

         # store OG in history-aware DB cache
         yield self.store.store_url_cache(
@@ -301,20 +301,20 @@ class PreviewUrlResource(Resource):
             logger.warn("Error downloading %s: %r", url, e)
             raise SynapseError(
                 500, "Failed to download content: %s" % (
-                    traceback.format_exception_only(sys.exc_type, e),
+                    traceback.format_exception_only(sys.exc_info()[0], e),
                 ),
                 Codes.UNKNOWN,
             )
         yield finish()

         try:
-            if "Content-Type" in headers:
-                media_type = headers["Content-Type"][0]
+            if b"Content-Type" in headers:
+                media_type = headers[b"Content-Type"][0].decode('ascii')
             else:
                 media_type = "application/octet-stream"
             time_now_ms = self.clock.time_msec()

-            content_disposition = headers.get("Content-Disposition", None)
+            content_disposition = headers.get(b"Content-Disposition", None)
             if content_disposition:
                 _, params = cgi.parse_header(content_disposition[0],)
                 download_name = None


@@ -929,6 +929,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
             txn, self.get_users_in_room, (room_id,)
         )

+        self._invalidate_cache_and_stream(
+            txn, self.get_room_summary, (room_id,)
+        )
+
         self._invalidate_cache_and_stream(
             txn, self.get_current_state_ids, (room_id,)
         )
@@ -1886,20 +1890,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
             ")"
         )

-        # create an index on should_delete because later we'll be looking for
-        # the should_delete / shouldn't_delete subsets
-        txn.execute(
-            "CREATE INDEX events_to_purge_should_delete"
-            " ON events_to_purge(should_delete)",
-        )
-
-        # We do joins against events_to_purge for e.g. calculating state
-        # groups to purge, etc., so lets make an index.
-        txn.execute(
-            "CREATE INDEX events_to_purge_id"
-            " ON events_to_purge(event_id)",
-        )
-
         # First ensure that we're not about to delete all the forward extremeties
         txn.execute(
             "SELECT e.event_id, e.depth FROM events as e "
@@ -1926,19 +1916,45 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
         should_delete_params = ()
         if not delete_local_events:
             should_delete_expr += " AND event_id NOT LIKE ?"
-            should_delete_params += ("%:" + self.hs.hostname, )
+
+            # We include the parameter twice since we use the expression twice
+            should_delete_params += (
+                "%:" + self.hs.hostname,
+                "%:" + self.hs.hostname,
+            )

         should_delete_params += (room_id, token.topological)

+        # Note that we insert events that are outliers and aren't going to be
+        # deleted, as nothing will happen to them.
         txn.execute(
             "INSERT INTO events_to_purge"
             " SELECT event_id, %s"
             " FROM events AS e LEFT JOIN state_events USING (event_id)"
-            " WHERE e.room_id = ? AND topological_ordering < ?" % (
+            " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
+            % (
+                should_delete_expr,
                 should_delete_expr,
             ),
             should_delete_params,
         )
+
+        # We create the indices *after* insertion as that's a lot faster.
+
+        # create an index on should_delete because later we'll be looking for
+        # the should_delete / shouldn't_delete subsets
+        txn.execute(
+            "CREATE INDEX events_to_purge_should_delete"
+            " ON events_to_purge(should_delete)",
+        )
+
+        # We do joins against events_to_purge for e.g. calculating state
+        # groups to purge, etc., so lets make an index.
+        txn.execute(
+            "CREATE INDEX events_to_purge_id"
+            " ON events_to_purge(event_id)",
+        )
+
        txn.execute(
            "SELECT event_id, should_delete FROM events_to_purge"
        )
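The "we include the parameter twice" comment exists because `should_delete_expr` is interpolated in two places: the `SELECT` column list and the new outlier guard in `WHERE`. A minimal in-memory SQLite illustration of that shape (the table and rows here are invented for the demo):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (event_id TEXT, outlier INTEGER)")
conn.executemany(
    "INSERT INTO events VALUES (?, ?)",
    [("$a:local", 0), ("$b:remote", 0), ("$c:local", 1), ("$d:remote", 1)],
)

# The expression appears twice in the SQL, so its bind parameter must be
# supplied twice, in order.
should_delete_expr = "event_id NOT LIKE ?"
sql = (
    "SELECT event_id, (%s) AS should_delete FROM events"
    " WHERE (NOT outlier OR (%s))" % (should_delete_expr, should_delete_expr)
)
rows = conn.execute(sql, ("%:local", "%:local")).fetchall()
# Local outlier $c:local is excluded; remote outlier $d:remote is kept
# (it is marked for deletion anyway), matching the purge query's intent.
```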


@@ -134,6 +134,7 @@ class KeyStore(SQLBaseStore):
         """
         key_id = "%s:%s" % (verify_key.alg, verify_key.version)

+        # XXX fix this to not need a lock (#3819)
         def _txn(txn):
             self._simple_upsert_txn(
                 txn,


@@ -146,6 +146,23 @@ class MonthlyActiveUsersStore(SQLBaseStore):
                 return count
         return self.runInteraction("count_users", _count_users)

+    @defer.inlineCallbacks
+    def get_registered_reserved_users_count(self):
+        """Of the reserved threepids defined in config, how many are associated
+        with registered users?
+
+        Returns:
+            Deferred[int]: Number of real reserved users
+        """
+        count = 0
+        for tp in self.hs.config.mau_limits_reserved_threepids:
+            user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+                tp["medium"], tp["address"]
+            )
+            if user_id:
+                count = count + 1
+        defer.returnValue(count)
+
     @defer.inlineCallbacks
     def upsert_monthly_active_user(self, user_id):
         """
@@ -199,10 +216,14 @@ class MonthlyActiveUsersStore(SQLBaseStore):
         Args:
             user_id(str): the user_id to query
         """
         if self.hs.config.limit_usage_by_mau:
+            # Trial users and guests should not be included as part of MAU group
+            is_guest = yield self.is_guest(user_id)
+            if is_guest:
+                return
             is_trial = yield self.is_trial_user(user_id)
             if is_trial:
-                # we don't track trial users in the MAU table.
                 return

             last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
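Stripped of the `inlineCallbacks` machinery, the new `get_registered_reserved_users_count` reduces to a counting loop; in this sketch the lookup callable stands in for `get_user_id_by_threepid`:

```python
def count_registered_reserved_users(reserved_threepids, get_user_id_by_threepid):
    # Count how many of the configured reserved threepids actually map to a
    # registered user on this server.
    count = 0
    for tp in reserved_threepids:
        if get_user_id_by_threepid(tp["medium"], tp["address"]):
            count += 1
    return count
```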


@@ -51,6 +51,12 @@ ProfileInfo = namedtuple(
     "ProfileInfo", ("avatar_url", "display_name")
 )

+# "members" points to a truncated list of (user_id, event_id) tuples for users of
+# a given membership type, suitable for use in calculating heroes for a room.
+# "count" points to the total number of users of a given membership type.
+MemberSummary = namedtuple(
+    "MemberSummary", ("members", "count")
+)

 _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"

@@ -82,6 +88,65 @@ class RoomMemberWorkerStore(EventsWorkerStore):

             return [to_ascii(r[0]) for r in txn]
         return self.runInteraction("get_users_in_room", f)

+    @cached(max_entries=100000)
+    def get_room_summary(self, room_id):
+        """ Get the details of a room roughly suitable for use by the room
+        summary extension to /sync. Useful when lazy loading room members.
+        Args:
+            room_id (str): The room ID to query
+        Returns:
+            Deferred[dict[str, MemberSummary]]:
+                dict of membership states, pointing to a MemberSummary named tuple.
+        """
+
+        def _get_room_summary_txn(txn):
+            # first get counts.
+            # We do this all in one transaction to keep the cache small.
+            # FIXME: get rid of this when we have room_stats
+            sql = """
+                SELECT count(*), m.membership FROM room_memberships as m
+                INNER JOIN current_state_events as c
+                ON m.event_id = c.event_id
+                AND m.room_id = c.room_id
+                AND m.user_id = c.state_key
+                WHERE c.type = 'm.room.member' AND c.room_id = ?
+                GROUP BY m.membership
+            """
+
+            txn.execute(sql, (room_id,))
+            res = {}
+            for count, membership in txn:
+                summary = res.setdefault(to_ascii(membership), MemberSummary([], count))
+
+            # we order by membership and then fairly arbitrarily by event_id so
+            # heroes are consistent
+            sql = """
+                SELECT m.user_id, m.membership, m.event_id
+                FROM room_memberships as m
+                INNER JOIN current_state_events as c
+                ON m.event_id = c.event_id
+                AND m.room_id = c.room_id
+                AND m.user_id = c.state_key
+                WHERE c.type = 'm.room.member' AND c.room_id = ?
+                ORDER BY
+                    CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+                    m.event_id ASC
+                LIMIT ?
+            """
+
+            # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
+            txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
+            for user_id, membership, event_id in txn:
+                summary = res[to_ascii(membership)]
+                # we will always have a summary for this membership type at this
+                # point given the summary currently contains the counts.
+                members = summary.members
+                members.append((to_ascii(user_id), to_ascii(event_id)))
+
+            return res
+
+        return self.runInteraction("get_room_summary", _get_room_summary_txn)
+
     @cached()
     def get_invited_rooms_for_user(self, user_id):
         """ Get all the rooms the user is invited to


@@ -438,3 +438,55 @@ def _cancelled_to_timed_out_error(value, timeout):
         value.trap(CancelledError)
         raise DeferredTimeoutError(timeout, "Deferred")
     return value
+
+
+def timeout_no_seriously(deferred, timeout, reactor):
+    """The built-in Twisted deferred addTimeout (and the method above)
+    completely fail to time things out under some unknown circumstances.
+
+    Let's try a different way of timing things out and maybe that will make
+    things work?!
+
+    TODO: Kill this with fire.
+    """
+
+    new_d = defer.Deferred()
+
+    timed_out = [False]
+
+    def time_it_out():
+        timed_out[0] = True
+
+        if not new_d.called:
+            new_d.errback(DeferredTimeoutError(timeout, "Deferred"))
+
+        deferred.cancel()
+
+    delayed_call = reactor.callLater(timeout, time_it_out)
+
+    def convert_cancelled(value):
+        if timed_out[0]:
+            return _cancelled_to_timed_out_error(value, timeout)
+        return value
+
+    deferred.addBoth(convert_cancelled)
+
+    def cancel_timeout(result):
+        # stop the pending call to cancel the deferred if it's been fired
+        if delayed_call.active():
+            delayed_call.cancel()
+        return result
+
+    deferred.addBoth(cancel_timeout)
+
+    def success_cb(val):
+        if not new_d.called:
+            new_d.callback(val)
+
+    def failure_cb(val):
+        if not new_d.called:
+            new_d.errback(val)
+
+    deferred.addCallbacks(success_cb, failure_cb)
+
+    return new_d
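The heart of `timeout_no_seriously` is the `timed_out` mutable flag, which lets a late result be told apart from a cancellation the timer caused. A toy, thread-based analogue of that pattern (illustrative only; the real helper works on Twisted deferreds and a reactor, and all names here are invented):

```python
import threading
import time


def run_with_timeout(fn, timeout):
    # Run fn on a worker thread; whichever of {result, timer} arrives first
    # "settles" the outcome, and the loser sees the settled flag and backs off.
    result = {"value": None, "timed_out": False}
    settled = threading.Event()  # plays the role of "new_d already called"

    def time_it_out():
        # Like time_it_out() above: only fire if nothing has settled yet.
        if not settled.is_set():
            result["timed_out"] = True
            settled.set()

    def worker():
        value = fn()
        if not settled.is_set():
            result["value"] = value
            settled.set()

    timer = threading.Timer(timeout, time_it_out)
    timer.start()
    threading.Thread(target=worker, daemon=True).start()
    settled.wait()
    timer.cancel()  # like cancel_timeout(): stop the pending timer
    return result
```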


@@ -19,22 +19,40 @@ from twisted.conch.ssh.keys import Key
 from twisted.cred import checkers, portal

 PUBLIC_KEY = (
-    "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBEvLi8DVPrJ3/c9k2I/Az"
-    "64fxjHf9imyRJbixtQhlH9lfNjUIx+4LmrJH5QNRsFporcHDKOTwTTYLh5KmRpslkYHRivcJS"
-    "kbh/C+BR3utDS555mV"
+    "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5"
+    "XALqeK+7385NlLja3DE/DO9mGhnd9+bAy39EKT3sTV6+WXQ4yD0TvEEyUEMtjWkSEm6U32+C"
+    "DaS3TW/vPBUMeJQwq+Ydcif1UlnpXrDDTamD0AU9VaEvHq+3HAkipqn0TGpKON6aqk4vauDx"
+    "oXSsV5TXBVrxP/y7HpMOpU4GUWsaaacBTKKNnUaQB4UflvydaPJUuwdaCUJGTMjbhWrjVfK+"
+    "jslseSPxU6XvrkZMyCr4znxvuDxjMk1RGIdO7v+rbBMLEgqtSMNqJbYeVCnj2CFgc3fcTcld"
+    "X2uOJDrJb/WRlHulthCh"
 )

 PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY-----
-MIIByAIBAAJhAK8ycfDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW
-4sbUIZR/ZXzY1CMfuC5qyR+UDUbBaaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fw
-vgUd7rQ0ueeZlQIBIwJgbh+1VZfr7WftK5lu7MHtqE1S1vPWZQYE3+VUn8yJADyb
-Z4fsZaCrzW9lkIqXkE3GIY+ojdhZhkO1gbG0118sIgphwSWKRxK0mvh6ERxKqIt1
-xJEJO74EykXZV4oNJ8sjAjEA3J9r2ZghVhGN6V8DnQrTk24Td0E8hU8AcP0FVP+8
-PQm/g/aXf2QQkQT+omdHVEJrAjEAy0pL0EBH6EVS98evDCBtQw22OZT52qXlAwZ2
-gyTriKFVoqjeEjt3SZKKqXHSApP/AjBLpF99zcJJZRq2abgYlf9lv1chkrWqDHUu
-DZttmYJeEfiFBBavVYIF1dOlZT0G8jMCMBc7sOSZodFnAiryP+Qg9otSBjJ3bQML
-pSTqy7c3a2AScC/YyOwkDaICHnnD3XyjMwIxALRzl0tQEKMXs6hH8ToUdlLROCrP
-EhQ0wahUTCk1gKA4uPD6TMTChavbh4K63OvbKg==
+MIIEpQIBAAKCAQEAx4RgE2luCoRNt/u56x+Ixcd8i6vTo2hLOVwC6nivu9/OTZS4
+2twxPwzvZhoZ3ffmwMt/RCk97E1evll0OMg9E7xBMlBDLY1pEhJulN9vgg2kt01v
+7zwVDHiUMKvmHXIn9VJZ6V6ww02pg9AFPVWhLx6vtxwJIqap9ExqSjjemqpOL2rg
+8aF0rFeU1wVa8T/8ux6TDqVOBlFrGmmnAUyijZ1GkAeFH5b8nWjyVLsHWglCRkzI
+24Vq41Xyvo7JbHkj8VOl765GTMgq+M58b7g8YzJNURiHTu7/q2wTCxIKrUjDaiW2
+HlQp49ghYHN33E3JXV9rjiQ6yW/1kZR7pbYQoQIDAQABAoIBAQC8KJ0q8Wzzwh5B
+esa1dQHZ8+4DEsL/Amae66VcVwD0X3cCN1W2IZ7X5W0Ij2kBqr8V51RYhcR+S+Ek
+BtzSiBUBvbKGrqcMGKaUgomDIMzai99hd0gvCCyZnEW1OQhFkNkaRNXCfqiZJ27M
+fqvSUiU2eOwh9fCvmxoA6Of8o3FbzcJ+1GMcobWRllDtLmj6lgVbDzuA+0jC5daB
+9Tj1pBzu3wn3ufxiS+gBnJ+7NcXH3E73lqCcPa2ufbZ1haxfiGCnRIhFXuQDgxFX
+vKdEfDgtvas6r1ahGbc+b/q8E8fZT7cABuIU4yfOORK+MhpyWbvoyyzuVGKj3PKt
+KSPJu5CZAoGBAOkoJfAVyYteqKcmGTanGqQnAY43CaYf6GdSPX/jg+JmKZg0zqMC
+jWZUtPb93i+jnOInbrnuHOiHAxI8wmhEPed28H2lC/LU8PzlqFkZXKFZ4vLOhhRB
+/HeHCFIDosPFlohWi3b+GAjD7sXgnIuGmnXWe2ea/TS3yersifDEoKKjAoGBANsQ
+gJX2cJv1c3jhdgcs8vAt5zIOKcCLTOr/QPmVf/kxjNgndswcKHwsxE/voTO9q+TF
+v/6yCSTxAdjuKz1oIYWgi/dZo82bBKWxNRpgrGviU3/zwxiHlyIXUhzQu78q3VS/
+7S1XVbc7qMV++XkYKHPVD+nVG/gGzFxumX7MLXfrAoGBAJit9cn2OnjNj9uFE1W6
+r7N254ndeLAUjPe73xH0RtTm2a4WRopwjW/JYIetTuYbWgyujc+robqTTuuOZjAp
+H/CG7o0Ym251CypQqaFO/l2aowclPp/dZhpPjp9GSjuxFBZLtiBB3DNBOwbRQzIK
+/vLTdRQvZkgzYkI4i0vjNt3JAoGBANP8HSKBLymMlShlrSx2b8TB9tc2Y2riohVJ
+2ttqs0M2kt/dGJWdrgOz4mikL+983Olt/0P9juHDoxEEMK2kpcPEv40lnmBpYU7h
+s8yJvnBLvJe2EJYdJ8AipyAhUX1FgpbvfxmASP8eaUxsegeXvBWTGWojAoS6N2o+
+0KSl+l3vAoGAFqm0gO9f/Q1Se60YQd4l2PZeMnJFv0slpgHHUwegmd6wJhOD7zJ1
+CkZcXwiv7Nog7AI9qKJEUXLjoqL+vJskBzSOqU3tcd670YQMi1aXSXJqYE202K7o
+EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs=
 -----END RSA PRIVATE KEY-----"""


@@ -20,6 +20,7 @@ from prometheus_client import Counter

 from twisted.internet import defer

+from synapse.metrics import InFlightGauge
 from synapse.util.logcontext import LoggingContext

 logger = logging.getLogger(__name__)
@@ -45,6 +46,13 @@ block_db_txn_duration = Counter(
 block_db_sched_duration = Counter(
     "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"])

+# Tracks the number of blocks currently active
+in_flight = InFlightGauge(
+    "synapse_util_metrics_block_in_flight", "",
+    labels=["block_name"],
+    sub_metrics=["real_time_max", "real_time_sum"],
+)
+

 def measure_func(name):
     def wrapper(func):
@@ -82,10 +90,14 @@ class Measure(object):

         self.start_usage = self.start_context.get_resource_usage()

+        in_flight.register((self.name,), self._update_in_flight)
+
     def __exit__(self, exc_type, exc_val, exc_tb):
         if isinstance(exc_type, Exception) or not self.start_context:
             return

+        in_flight.unregister((self.name,), self._update_in_flight)
+
         duration = self.clock.time() - self.start

         block_counter.labels(self.name).inc()
@@ -120,3 +132,13 @@ class Measure(object):

         if self.created_context:
             self.start_context.__exit__(exc_type, exc_val, exc_tb)
+
+    def _update_in_flight(self, metrics):
+        """Gets called when processing in flight metrics
+        """
+        duration = self.clock.time() - self.start
+
+        metrics.real_time_max = max(metrics.real_time_max, duration)
+        metrics.real_time_sum += duration
+
+        # TODO: Add other in flight metrics.
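The `register`/`unregister` calls in `Measure` mean that, at collection time, every block still in flight is asked to fold its elapsed time into the gauge's sub-metrics. A toy tracker showing that shape (this class is only an illustration, not synapse's `InFlightGauge`):

```python
import time


class ToyInFlightGauge:
    # Holds per-key callbacks from currently-running blocks; collecting
    # asks every live block to update max/sum of its elapsed time.
    def __init__(self):
        self._callbacks = {}

    def register(self, key, callback):
        self._callbacks.setdefault(key, set()).add(callback)

    def unregister(self, key, callback):
        self._callbacks.get(key, set()).discard(callback)

    def collect(self, key):
        class Metrics(object):
            real_time_max = 0.0
            real_time_sum = 0.0
        metrics = Metrics()
        for callback in self._callbacks.get(key, ()):
            callback(metrics)
        return metrics
```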


@@ -471,6 +471,7 @@ class AuthTestCase(unittest.TestCase):
     def test_reserved_threepid(self):
         self.hs.config.limit_usage_by_mau = True
         self.hs.config.max_mau_value = 1
+        self.store.get_monthly_active_count = lambda: defer.succeed(2)
         threepid = {'medium': 'email', 'address': 'reserved@server.com'}
         unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'}
         self.hs.config.mau_limits_reserved_threepids = [threepid]


@@ -47,7 +47,7 @@ class FrontendProxyTests(HomeserverTestCase):
        self.assertEqual(len(self.reactor.tcpServers), 1)
        site = self.reactor.tcpServers[0][1]
        self.resource = (
            site.resource.children[b"_matrix"].children[b"client"].children[b"r0"]
        )

        request, channel = self.make_request("PUT", "presence/a/status")
@@ -77,7 +77,7 @@ class FrontendProxyTests(HomeserverTestCase):
        self.assertEqual(len(self.reactor.tcpServers), 1)
        site = self.reactor.tcpServers[0][1]
        self.resource = (
            site.resource.children[b"_matrix"].children[b"client"].children[b"r0"]
        )

        request, channel = self.make_request("PUT", "presence/a/status")


@@ -43,9 +43,7 @@ def _expect_edu_transaction(edu_type, content, origin="test"):
def _make_edu_transaction_json(edu_type, content):
    return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8')


class TypingNotificationsTestCase(unittest.TestCase):


@@ -0,0 +1,157 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from twisted.internet.defer import TimeoutError
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
from twisted.web.client import ResponseNeverReceived
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from tests.unittest import HomeserverTestCase
class FederationClientTests(HomeserverTestCase):
    def make_homeserver(self, reactor, clock):
        hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
        hs.tls_client_options_factory = None
        return hs

    def prepare(self, reactor, clock, homeserver):
        self.cl = MatrixFederationHttpClient(self.hs)
        self.reactor.lookups["testserv"] = "1.2.3.4"

    def test_dns_error(self):
        """
        If the DNS lookup returns an error, it will bubble up.
        """
        d = self.cl._request("testserv2:8008", "GET", "foo/bar", timeout=10000)
        self.pump()

        f = self.failureResultOf(d)
        self.assertIsInstance(f.value, DNSLookupError)

    def test_client_never_connect(self):
        """
        If the HTTP request is not connected and is timed out, it'll give a
        ConnectingCancelledError.
        """
        d = self.cl._request("testserv:8008", "GET", "foo/bar", timeout=10000)
        self.pump()

        # Nothing happened yet
        self.assertFalse(d.called)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        self.assertEqual(clients[0][0], '1.2.3.4')
        self.assertEqual(clients[0][1], 8008)

        # Deferred is still without a result
        self.assertFalse(d.called)

        # Push by enough to time it out
        self.reactor.advance(10.5)
        f = self.failureResultOf(d)
        self.assertIsInstance(f.value, ConnectingCancelledError)

    def test_client_connect_no_response(self):
        """
        If the HTTP request is connected, but gets no response before being
        timed out, it'll give a ResponseNeverReceived.
        """
        d = self.cl._request("testserv:8008", "GET", "foo/bar", timeout=10000)
        self.pump()

        # Nothing happened yet
        self.assertFalse(d.called)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        self.assertEqual(clients[0][0], '1.2.3.4')
        self.assertEqual(clients[0][1], 8008)

        conn = Mock()
        client = clients[0][2].buildProtocol(None)
        client.makeConnection(conn)

        # Deferred is still without a result
        self.assertFalse(d.called)

        # Push by enough to time it out
        self.reactor.advance(10.5)
        f = self.failureResultOf(d)
        self.assertIsInstance(f.value, ResponseNeverReceived)

    def test_client_gets_headers(self):
        """
        Once the client gets the headers, _request returns successfully.
        """
        d = self.cl._request("testserv:8008", "GET", "foo/bar", timeout=10000)
        self.pump()

        conn = Mock()
        clients = self.reactor.tcpClients
        client = clients[0][2].buildProtocol(None)
        client.makeConnection(conn)

        # Deferred does not have a result
        self.assertFalse(d.called)

        # Send it the HTTP response
        client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")

        # We should get a successful response
        r = self.successResultOf(d)
        self.assertEqual(r.code, 200)

    def test_client_headers_no_body(self):
        """
        If the HTTP request is connected and gets headers, but no response
        body arrives before the timeout, it'll give a TimeoutError.
        """
        d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
        self.pump()

        conn = Mock()
        clients = self.reactor.tcpClients
        client = clients[0][2].buildProtocol(None)
        client.makeConnection(conn)

        # Deferred does not have a result
        self.assertFalse(d.called)

        # Send it the HTTP response
        client.dataReceived(
            (b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
             b"Server: Fake\r\n\r\n")
        )

        # Push by enough to time it out
        self.reactor.advance(10.5)
        f = self.failureResultOf(d)
        self.assertIsInstance(f.value, TimeoutError)
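These tests never actually wait: timeouts are scheduled against a fake reactor clock, and `self.reactor.advance(10.5)` fires them deterministically. A minimal sketch of that idea, assuming toy stand-ins (`FakeClock` and `TimeoutRequest` are hypothetical, not Twisted's reactor or `Deferred`):

```python
class FakeClock:
    """Toy controllable clock: callbacks fire only when a test advances time."""

    def __init__(self):
        self.now = 0.0
        self._calls = []  # list of (deadline, fn)

    def call_later(self, delay, fn):
        self._calls.append((self.now + delay, fn))

    def advance(self, seconds):
        # Move time forward and run every callback whose deadline has passed.
        self.now += seconds
        due = [c for c in self._calls if c[0] <= self.now]
        self._calls = [c for c in self._calls if c[0] > self.now]
        for _, fn in sorted(due, key=lambda c: c[0]):
            fn()


class TimeoutRequest:
    """Toy request that fails with a timeout unless completed first."""

    def __init__(self, clock, timeout):
        self.result = None
        clock.call_later(timeout, self._time_out)

    def complete(self, value):
        if self.result is None:
            self.result = value

    def _time_out(self):
        # Mirrors the tests above: advancing past the deadline yields a failure.
        if self.result is None:
            self.result = "timed out"
```

Advancing by 10.5 against a 10-second deadline deterministically produces the failure, which is the same shape as `self.reactor.advance(10.5)` followed by `failureResultOf` in the tests above.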


@@ -22,39 +22,24 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer

from synapse.api.constants import Membership
from synapse.rest.client.v1 import room

from tests import unittest

PATH_PREFIX = b"/_matrix/client/api/v1"


class RoomBase(unittest.HomeserverTestCase):
    rmcreator_id = None

    servlets = [room.register_servlets, room.register_deprecated_servlets]

    def make_homeserver(self, reactor, clock):
        self.hs = self.setup_test_homeserver(
            "red",
            http_client=None,
            federation_client=Mock(),
            ratelimiter=NonCallableMock(spec_set=["send_message"]),
        )
@@ -63,42 +48,21 @@ class RoomBase(unittest.TestCase):
        self.hs.get_federation_handler = Mock(return_value=Mock())

        def _insert_client_ip(*args, **kwargs):
            return defer.succeed(None)

        self.hs.get_datastore().insert_client_ip = _insert_client_ip

        return self.hs
class RoomPermissionsTestCase(RoomBase):
    """ Tests room permissions. """

    user_id = "@sid1:red"
    rmcreator_id = "@notme:red"

    def prepare(self, reactor, clock, hs):
        self.helper.auth_user_id = self.rmcreator_id
        # create some rooms under the name rmcreator_id
@@ -114,22 +78,20 @@ class RoomPermissionsTestCase(RoomBase):
        self.created_rmid_msg_path = (
            "rooms/%s/send/m.room.message/a1" % (self.created_rmid)
        ).encode('ascii')
        request, channel = self.make_request(
            "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
        )
        self.render(request)
        self.assertEquals(200, channel.code, channel.result)

        # set topic for public room
        request, channel = self.make_request(
            "PUT",
            ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'),
            b'{"topic":"Public Room Topic"}',
        )
        self.render(request)
        self.assertEquals(200, channel.code, channel.result)

        # auth as user_id now
        self.helper.auth_user_id = self.user_id
@@ -140,128 +102,128 @@ class RoomPermissionsTestCase(RoomBase):
        seq = iter(range(100))

        def send_msg_path():
            return "/rooms/%s/send/m.room.message/mid%s" % (
                self.created_rmid,
                str(next(seq)),
            )

        # send message in uncreated room, expect 403
        request, channel = self.make_request(
            "PUT",
            "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
            msg_content,
        )
        self.render(request)
        self.assertEquals(403, channel.code, msg=channel.result["body"])

        # send message in created room not joined (no state), expect 403
        request, channel = self.make_request("PUT", send_msg_path(), msg_content)
        self.render(request)
        self.assertEquals(403, channel.code, msg=channel.result["body"])

        # send message in created room and invited, expect 403
        self.helper.invite(
            room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
        )
        request, channel = self.make_request("PUT", send_msg_path(), msg_content)
        self.render(request)
        self.assertEquals(403, channel.code, msg=channel.result["body"])

        # send message in created room and joined, expect 200
        self.helper.join(room=self.created_rmid, user=self.user_id)
        request, channel = self.make_request("PUT", send_msg_path(), msg_content)
        self.render(request)
        self.assertEquals(200, channel.code, msg=channel.result["body"])

        # send message in created room and left, expect 403
        self.helper.leave(room=self.created_rmid, user=self.user_id)
        request, channel = self.make_request("PUT", send_msg_path(), msg_content)
        self.render(request)
        self.assertEquals(403, channel.code, msg=channel.result["body"])
    def test_topic_perms(self):
        topic_content = b'{"topic":"My Topic Name"}'
        topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid

        # set/get topic in uncreated room, expect 403
        request, channel = self.make_request(
            "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
        )
        self.render(request)
        self.assertEquals(403, channel.code, msg=channel.result["body"])
        request, channel = self.make_request(
            "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
        )
        self.render(request)
        self.assertEquals(403, channel.code, msg=channel.result["body"])

        # set/get topic in created PRIVATE room not joined, expect 403
        request, channel = self.make_request("PUT", topic_path, topic_content)
        self.render(request)
        self.assertEquals(403, channel.code, msg=channel.result["body"])
        request, channel = self.make_request("GET", topic_path)
        self.render(request)
        self.assertEquals(403, channel.code, msg=channel.result["body"])

        # set topic in created PRIVATE room and invited, expect 403
        self.helper.invite(
            room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
        )
        request, channel = self.make_request("PUT", topic_path, topic_content)
        self.render(request)
        self.assertEquals(403, channel.code, msg=channel.result["body"])

        # get topic in created PRIVATE room and invited, expect 403
        request, channel = self.make_request("GET", topic_path)
        self.render(request)
        self.assertEquals(403, channel.code, msg=channel.result["body"])

        # set/get topic in created PRIVATE room and joined, expect 200
        self.helper.join(room=self.created_rmid, user=self.user_id)

        # Only room ops can set topic by default
        self.helper.auth_user_id = self.rmcreator_id
        request, channel = self.make_request("PUT", topic_path, topic_content)
        self.render(request)
        self.assertEquals(200, channel.code, msg=channel.result["body"])
        self.helper.auth_user_id = self.user_id

        request, channel = self.make_request("GET", topic_path)
        self.render(request)
        self.assertEquals(200, channel.code, msg=channel.result["body"])
        self.assert_dict(json.loads(topic_content.decode('utf8')), channel.json_body)

        # set/get topic in created PRIVATE room and left, expect 403
        self.helper.leave(room=self.created_rmid, user=self.user_id)
        request, channel = self.make_request("PUT", topic_path, topic_content)
        self.render(request)
        self.assertEquals(403, channel.code, msg=channel.result["body"])
        request, channel = self.make_request("GET", topic_path)
        self.render(request)
        self.assertEquals(200, channel.code, msg=channel.result["body"])

        # get topic in PUBLIC room, not joined, expect 403
        request, channel = self.make_request(
            "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid
        )
        self.render(request)
        self.assertEquals(403, channel.code, msg=channel.result["body"])

        # set topic in PUBLIC room, not joined, expect 403
        request, channel = self.make_request(
            "PUT",
            "/rooms/%s/state/m.room.topic" % self.created_public_rmid,
            topic_content,
        )
        self.render(request)
        self.assertEquals(403, channel.code, msg=channel.result["body"])

    def _test_get_membership(self, room=None, members=[], expect_code=None):
        for member in members:
            path = "/rooms/%s/state/m.room.member/%s" % (room, member)
            request, channel = self.make_request("GET", path)
            self.render(request)
            self.assertEquals(expect_code, channel.code)
    def test_membership_basic_room_perms(self):
        # === room does not exist ===
@@ -428,217 +390,211 @@ class RoomPermissionsTestCase(RoomBase):
class RoomsMemberListTestCase(RoomBase):
    """ Tests /rooms/$room_id/members/list REST events."""

    user_id = "@sid1:red"

    def test_get_member_list(self):
        room_id = self.helper.create_room_as(self.user_id)
        request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
        self.render(request)
        self.assertEquals(200, channel.code, msg=channel.result["body"])

    def test_get_member_list_no_room(self):
        request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
        self.render(request)
        self.assertEquals(403, channel.code, msg=channel.result["body"])

    def test_get_member_list_no_permission(self):
        room_id = self.helper.create_room_as("@some_other_guy:red")
        request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
        self.render(request)
        self.assertEquals(403, channel.code, msg=channel.result["body"])

    def test_get_member_list_mixed_memberships(self):
        room_creator = "@some_other_guy:red"
        room_id = self.helper.create_room_as(room_creator)
        room_path = "/rooms/%s/members" % room_id
        self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)

        # can't see list if you're just invited.
        request, channel = self.make_request("GET", room_path)
        self.render(request)
        self.assertEquals(403, channel.code, msg=channel.result["body"])

        self.helper.join(room=room_id, user=self.user_id)

        # can see list now joined
        request, channel = self.make_request("GET", room_path)
        self.render(request)
        self.assertEquals(200, channel.code, msg=channel.result["body"])

        self.helper.leave(room=room_id, user=self.user_id)

        # can see old list once left
        request, channel = self.make_request("GET", room_path)
        self.render(request)
        self.assertEquals(200, channel.code, msg=channel.result["body"])
class RoomsCreateTestCase(RoomBase):
    """ Tests /rooms and /rooms/$room_id REST events. """

    user_id = "@sid1:red"

    def test_post_room_no_keys(self):
        # POST with no config keys, expect new room id
        request, channel = self.make_request("POST", "/createRoom", "{}")
        self.render(request)
        self.assertEquals(200, channel.code, channel.result)
        self.assertTrue("room_id" in channel.json_body)

    def test_post_room_visibility_key(self):
        # POST with visibility config key, expect new room id
        request, channel = self.make_request(
            "POST", "/createRoom", b'{"visibility":"private"}'
        )
        self.render(request)
        self.assertEquals(200, channel.code)
        self.assertTrue("room_id" in channel.json_body)

    def test_post_room_custom_key(self):
        # POST with custom config keys, expect new room id
        request, channel = self.make_request(
            "POST", "/createRoom", b'{"custom":"stuff"}'
        )
        self.render(request)
        self.assertEquals(200, channel.code)
        self.assertTrue("room_id" in channel.json_body)

    def test_post_room_known_and_unknown_keys(self):
        # POST with custom + known config keys, expect new room id
        request, channel = self.make_request(
            "POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
        )
        self.render(request)
        self.assertEquals(200, channel.code)
        self.assertTrue("room_id" in channel.json_body)

    def test_post_room_invalid_content(self):
        # POST with invalid content / paths, expect 400
        request, channel = self.make_request("POST", "/createRoom", b'{"visibili')
        self.render(request)
        self.assertEquals(400, channel.code)

        request, channel = self.make_request("POST", "/createRoom", b'["hello"]')
        self.render(request)
        self.assertEquals(400, channel.code)
class RoomTopicTestCase(RoomBase):
    """ Tests /rooms/$room_id/topic REST events. """

    user_id = "@sid1:red"

    def prepare(self, reactor, clock, hs):
        # create the room
        self.room_id = self.helper.create_room_as(self.user_id)
        self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,)

    def test_invalid_puts(self):
        # missing keys or invalid json
        request, channel = self.make_request("PUT", self.path, '{}')
        self.render(request)
        self.assertEquals(400, channel.code, msg=channel.result["body"])

        request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
        self.render(request)
        self.assertEquals(400, channel.code, msg=channel.result["body"])

        request, channel = self.make_request("PUT", self.path, '{"nao')
        self.render(request)
        self.assertEquals(400, channel.code, msg=channel.result["body"])

        request, channel = self.make_request(
            "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]'
        )
        self.render(request)
        self.assertEquals(400, channel.code, msg=channel.result["body"])

        request, channel = self.make_request("PUT", self.path, 'text only')
        self.render(request)
        self.assertEquals(400, channel.code, msg=channel.result["body"])

        request, channel = self.make_request("PUT", self.path, '')
        self.render(request)
        self.assertEquals(400, channel.code, msg=channel.result["body"])

        # valid key, wrong type
        content = '{"topic":["Topic name"]}'
        request, channel = self.make_request("PUT", self.path, content)
        self.render(request)
        self.assertEquals(400, channel.code, msg=channel.result["body"])

    def test_rooms_topic(self):
        # nothing should be there
        request, channel = self.make_request("GET", self.path)
        self.render(request)
        self.assertEquals(404, channel.code, msg=channel.result["body"])

        # valid put
        content = '{"topic":"Topic name"}'
        request, channel = self.make_request("PUT", self.path, content)
        self.render(request)
        self.assertEquals(200, channel.code, msg=channel.result["body"])

        # valid get
        request, channel = self.make_request("GET", self.path)
        self.render(request)
        self.assertEquals(200, channel.code, msg=channel.result["body"])
        self.assert_dict(json.loads(content), channel.json_body)

    def test_rooms_topic_with_extra_keys(self):
        # valid put with extra keys
        content = '{"topic":"Seasons","subtopic":"Summer"}'
        request, channel = self.make_request("PUT", self.path, content)
        self.render(request)
        self.assertEquals(200, channel.code, msg=channel.result["body"])

        # valid get
        request, channel = self.make_request("GET", self.path)
        self.render(request)
        self.assertEquals(200, channel.code, msg=channel.result["body"])
        self.assert_dict(json.loads(content), channel.json_body)
class RoomMemberStateTestCase(RoomBase): class RoomMemberStateTestCase(RoomBase):
""" Tests /rooms/$room_id/members/$user_id/state REST events. """ """ Tests /rooms/$room_id/members/$user_id/state REST events. """
user_id = b"@sid1:red" user_id = "@sid1:red"
def setUp(self): def prepare(self, reactor, clock, hs):
super(RoomMemberStateTestCase, self).setUp()
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
def tearDown(self):
pass
def test_invalid_puts(self): def test_invalid_puts(self):
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json # missing keys or invalid json
request, channel = make_request(b"PUT", path, '{}') request, channel = self.make_request("PUT", path, '{}')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '{"_name":"bob"}') request, channel = self.make_request("PUT", path, '{"_name":"bo"}')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '{"nao') request, channel = self.make_request("PUT", path, '{"nao')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request( request, channel = self.make_request(
b"PUT", path, b'[{"_name":"bob"},{"_name":"jill"}]' "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, 'text only') request, channel = self.make_request("PUT", path, 'text only')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '') request, channel = self.make_request("PUT", path, '')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
# valid keys, wrong types # valid keys, wrong types
content = '{"membership":["%s","%s","%s"]}' % ( content = '{"membership":["%s","%s","%s"]}' % (
@ -646,9 +602,9 @@ class RoomMemberStateTestCase(RoomBase):
Membership.JOIN, Membership.JOIN,
Membership.LEAVE, Membership.LEAVE,
) )
request, channel = make_request(b"PUT", path, content.encode('ascii')) request, channel = self.make_request("PUT", path, content.encode('ascii'))
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_members_self(self): def test_rooms_members_self(self):
path = "/rooms/%s/state/m.room.member/%s" % ( path = "/rooms/%s/state/m.room.member/%s" % (
@ -658,13 +614,13 @@ class RoomMemberStateTestCase(RoomBase):
# valid join message (NOOP since we made the room) # valid join message (NOOP since we made the room)
content = '{"membership":"%s"}' % Membership.JOIN content = '{"membership":"%s"}' % Membership.JOIN
request, channel = make_request(b"PUT", path, content.encode('ascii')) request, channel = self.make_request("PUT", path, content.encode('ascii'))
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
request, channel = make_request(b"GET", path, None) request, channel = self.make_request("GET", path, None)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
expected_response = {"membership": Membership.JOIN} expected_response = {"membership": Membership.JOIN}
self.assertEquals(expected_response, channel.json_body) self.assertEquals(expected_response, channel.json_body)
@ -678,13 +634,13 @@ class RoomMemberStateTestCase(RoomBase):
# valid invite message # valid invite message
content = '{"membership":"%s"}' % Membership.INVITE content = '{"membership":"%s"}' % Membership.INVITE
request, channel = make_request(b"PUT", path, content) request, channel = self.make_request("PUT", path, content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
request, channel = make_request(b"GET", path, None) request, channel = self.make_request("GET", path, None)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assertEquals(json.loads(content), channel.json_body) self.assertEquals(json.loads(content), channel.json_body)
def test_rooms_members_other_custom_keys(self): def test_rooms_members_other_custom_keys(self):
@ -699,13 +655,13 @@ class RoomMemberStateTestCase(RoomBase):
Membership.INVITE, Membership.INVITE,
"Join us!", "Join us!",
) )
request, channel = make_request(b"PUT", path, content) request, channel = self.make_request("PUT", path, content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
request, channel = make_request(b"GET", path, None) request, channel = self.make_request("GET", path, None)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assertEquals(json.loads(content), channel.json_body) self.assertEquals(json.loads(content), channel.json_body)
@ -714,60 +670,58 @@ class RoomMessagesTestCase(RoomBase):
user_id = "@sid1:red" user_id = "@sid1:red"
def setUp(self): def prepare(self, reactor, clock, hs):
super(RoomMessagesTestCase, self).setUp()
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
def test_invalid_puts(self): def test_invalid_puts(self):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json # missing keys or invalid json
request, channel = make_request(b"PUT", path, '{}') request, channel = self.make_request("PUT", path, b'{}')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '{"_name":"bob"}') request, channel = self.make_request("PUT", path, b'{"_name":"bo"}')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '{"nao') request, channel = self.make_request("PUT", path, b'{"nao')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request( request, channel = self.make_request(
b"PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, 'text only') request, channel = self.make_request("PUT", path, b'text only')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = make_request(b"PUT", path, '') request, channel = self.make_request("PUT", path, b'')
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_messages_sent(self): def test_rooms_messages_sent(self):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
content = '{"body":"test","msgtype":{"type":"a"}}' content = b'{"body":"test","msgtype":{"type":"a"}}'
request, channel = make_request(b"PUT", path, content) request, channel = self.make_request("PUT", path, content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(400, channel.code, msg=channel.result["body"])
# custom message types # custom message types
content = '{"body":"test","msgtype":"test.custom.text"}' content = b'{"body":"test","msgtype":"test.custom.text"}'
request, channel = make_request(b"PUT", path, content) request, channel = self.make_request("PUT", path, content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
# m.text message type # m.text message type
path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
content = '{"body":"test2","msgtype":"m.text"}' content = b'{"body":"test2","msgtype":"m.text"}'
request, channel = make_request(b"PUT", path, content) request, channel = self.make_request("PUT", path, content)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEquals(200, channel.code, msg=channel.result["body"])
class RoomInitialSyncTestCase(RoomBase): class RoomInitialSyncTestCase(RoomBase):
@ -775,16 +729,16 @@ class RoomInitialSyncTestCase(RoomBase):
user_id = "@sid1:red" user_id = "@sid1:red"
def setUp(self): def prepare(self, reactor, clock, hs):
super(RoomInitialSyncTestCase, self).setUp()
# create the room # create the room
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
def test_initial_sync(self): def test_initial_sync(self):
request, channel = make_request(b"GET", "/rooms/%s/initialSync" % self.room_id) request, channel = self.make_request(
render(request, self.resource, self.clock) "GET", "/rooms/%s/initialSync" % self.room_id
self.assertEquals(200, int(channel.result["code"])) )
self.render(request)
self.assertEquals(200, channel.code)
self.assertEquals(self.room_id, channel.json_body["room_id"]) self.assertEquals(self.room_id, channel.json_body["room_id"])
self.assertEquals("join", channel.json_body["membership"]) self.assertEquals("join", channel.json_body["membership"])
@ -819,17 +773,16 @@ class RoomMessageListTestCase(RoomBase):
user_id = "@sid1:red" user_id = "@sid1:red"
def setUp(self): def prepare(self, reactor, clock, hs):
super(RoomMessageListTestCase, self).setUp()
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
def test_topo_token_is_accepted(self): def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0_0_0" token = "t1-0_0_0_0_0_0_0_0_0"
request, channel = make_request( request, channel = self.make_request(
b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"])) self.assertEquals(200, channel.code)
self.assertTrue("start" in channel.json_body) self.assertTrue("start" in channel.json_body)
self.assertEquals(token, channel.json_body['start']) self.assertEquals(token, channel.json_body['start'])
self.assertTrue("chunk" in channel.json_body) self.assertTrue("chunk" in channel.json_body)
@ -837,11 +790,11 @@ class RoomMessageListTestCase(RoomBase):
def test_stream_token_is_accepted_for_fwd_pagianation(self): def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0_0_0" token = "s0_0_0_0_0_0_0_0_0"
request, channel = make_request( request, channel = self.make_request(
b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(200, int(channel.result["code"])) self.assertEquals(200, channel.code)
self.assertTrue("start" in channel.json_body) self.assertTrue("start" in channel.json_body)
self.assertEquals(token, channel.json_body['start']) self.assertEquals(token, channel.json_body['start'])
self.assertTrue("chunk" in channel.json_body) self.assertTrue("chunk" in channel.json_body)
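The hunks above mechanically swap the old module-level helpers (`make_request(...)` and `render(request, self.resource, self.clock)`) for methods inherited from the new `HomeserverTestCase` base (`self.make_request(...)`, `self.render(...)`), and read the parsed status via `channel.code` rather than `int(channel.result["code"])`. A rough self-contained sketch of that shape — the classes below are illustrative stand-ins, not Synapse's real test helpers:

```python
import json

class FakeChannel:
    """Stand-in for the test HTTP channel: keeps the raw result dict but
    also exposes the parsed `code` property the refactored tests prefer."""

    def __init__(self, code, body):
        self.result = {"code": str(code), "body": body}

    @property
    def code(self):
        return int(self.result["code"])

class BaseTestCase:
    """Stand-in base class: the request helpers now live on the class, so
    tests call self.make_request(...) / self.render(...) instead of free
    functions that need the resource and clock passed in every time."""

    def make_request(self, method, path, content=b""):
        # Fake dispatch: reject anything that isn't valid JSON, which is
        # what the "invalid puts" tests above expect the server to do.
        try:
            json.loads(content if isinstance(content, str) else content.decode())
            code = 200
        except ValueError:
            code = 400
        return ("request", method, path), FakeChannel(code, content)

    def render(self, request):
        pass  # the real helper drives the reactor; nothing to do in a sketch

case = BaseTestCase()
request, channel = case.make_request("PUT", "/rooms/!r:test/state/m.room.topic", "text only")
case.render(request)
assert channel.code == 400  # non-JSON body is rejected

request, channel = case.make_request("PUT", "/rooms/!r:test/state/m.room.topic", '{"topic":"x"}')
assert channel.code == 200
```

Moving the helpers onto the base class also lets `prepare(self, reactor, clock, hs)` replace the old `setUp`/`tearDown` pairs, as the class hunks above show.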

View file

@@ -62,12 +62,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertTrue(
             set(
-                [
-                    "next_batch",
-                    "rooms",
-                    "account_data",
-                    "to_device",
-                    "device_lists",
-                ]
+                ["next_batch", "rooms", "account_data", "to_device", "device_lists"]
             ).issubset(set(channel.json_body.keys()))
         )
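The reformatted assertion checks that a required set of `/sync` response keys is a subset of whatever the server returned, so extra keys never fail the test. A minimal illustration (the response dict here is made up):

```python
# Keys the sync test requires (copied from the hunk above).
required = set(["next_batch", "rooms", "account_data", "to_device", "device_lists"])

# Hypothetical response body, for illustration only.
json_body = {
    "next_batch": "s72_1",
    "rooms": {},
    "account_data": {},
    "to_device": {},
    "device_lists": {},
    "presence": {},  # extra keys are tolerated: issubset ignores them
}

assert required.issubset(set(json_body.keys()))
```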

View file

@@ -4,9 +4,14 @@
 from io import BytesIO

 from six import text_type

 import attr
+from zope.interface import implementer

-from twisted.internet import address, threads
+from twisted.internet import address, threads, udp
+from twisted.internet._resolver import HostResolution
+from twisted.internet.address import IPv4Address
 from twisted.internet.defer import Deferred
+from twisted.internet.error import DNSLookupError
+from twisted.internet.interfaces import IReactorPluggableNameResolver
 from twisted.python.failure import Failure
 from twisted.test.proto_helpers import MemoryReactorClock
@@ -65,7 +70,7 @@ class FakeChannel(object):
     def getPeer(self):
         # We give an address so that getClientIP returns a non null entry,
         # causing us to record the MAU
-        return address.IPv4Address(b"TCP", "127.0.0.1", 3423)
+        return address.IPv4Address("TCP", "127.0.0.1", 3423)

     def getHost(self):
         return None
@@ -154,11 +159,46 @@ def render(request, resource, clock):
     wait_until_result(clock, request)


+@implementer(IReactorPluggableNameResolver)
 class ThreadedMemoryReactorClock(MemoryReactorClock):
     """
     A MemoryReactorClock that supports callFromThread.
     """

+    def __init__(self):
+        self._udp = []
+        self.lookups = {}
+
+        class Resolver(object):
+            def resolveHostName(
+                _self,
+                resolutionReceiver,
+                hostName,
+                portNumber=0,
+                addressTypes=None,
+                transportSemantics='TCP',
+            ):
+                resolution = HostResolution(hostName)
+                resolutionReceiver.resolutionBegan(resolution)
+                if hostName not in self.lookups:
+                    raise DNSLookupError("OH NO")
+
+                resolutionReceiver.addressResolved(
+                    IPv4Address('TCP', self.lookups[hostName], portNumber)
+                )
+                resolutionReceiver.resolutionComplete()
+                return resolution
+
+        self.nameResolver = Resolver()
+        super(ThreadedMemoryReactorClock, self).__init__()
+
+    def listenUDP(self, port, protocol, interface='', maxPacketSize=8196):
+        p = udp.Port(port, protocol, interface, maxPacketSize, self)
+        p.startListening()
+        self._udp.append(p)
+        return p
+
     def callFromThread(self, callback, *args, **kwargs):
         """
         Make the callback fire in the next reactor iteration.
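The new `__init__` above seeds a `lookups` dict and plugs a resolver into the fake reactor that only answers for pre-seeded hostnames, failing loudly (`DNSLookupError("OH NO")`) for anything else, so no test can accidentally hit real DNS. A plain-Python model of that idea (no Twisted; `FakeNameResolver` is an illustrative name, not part of the diff):

```python
class FakeNameResolver:
    """Resolves only names seeded into `lookups`; unknown names fail fast,
    mirroring the DNSLookupError("OH NO") behaviour in the diff above."""

    def __init__(self):
        self.lookups = {}  # hostname -> IP string, seeded by each test

    def resolve(self, hostname, port=0):
        if hostname not in self.lookups:
            # stand-in for twisted.internet.error.DNSLookupError
            raise LookupError("OH NO")
        return (self.lookups[hostname], port)

resolver = FakeNameResolver()
resolver.lookups["matrix.org"] = "1.2.3.4"

assert resolver.resolve("matrix.org", 8448) == ("1.2.3.4", 8448)

# Unseeded names must not silently resolve:
try:
    resolver.resolve("other.example")
    raised = False
except LookupError:
    raised = True
assert raised
```

The real version goes through Twisted's `IReactorPluggableNameResolver`/`resolveHostName` callback protocol, but the control flow is the same: seeded name resolves, anything else raises.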

View file

@@ -80,12 +80,11 @@ class TestResourceLimitsServerNotices(unittest.TestCase):
         self._rlsn._auth.check_auth_blocking = Mock()

         mock_event = Mock(
-            type=EventTypes.Message,
-            content={"msgtype": ServerNoticeMsgType},
+            type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
-        self._rlsn._store.get_events = Mock(return_value=defer.succeed(
-            {"123": mock_event}
-        ))
+        self._rlsn._store.get_events = Mock(
+            return_value=defer.succeed({"123": mock_event})
+        )
         yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)

         # Would be better to check the content, but once == remove blocking event
@@ -99,12 +98,11 @@ class TestResourceLimitsServerNotices(unittest.TestCase):
         )
         mock_event = Mock(
-            type=EventTypes.Message,
-            content={"msgtype": ServerNoticeMsgType},
+            type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
-        self._rlsn._store.get_events = Mock(return_value=defer.succeed(
-            {"123": mock_event}
-        ))
+        self._rlsn._store.get_events = Mock(
+            return_value=defer.succeed({"123": mock_event})
+        )
         yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)

         self._send_notice.assert_not_called()
@@ -177,13 +175,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase):
     @defer.inlineCallbacks
     def test_server_notice_only_sent_once(self):
-        self.store.get_monthly_active_count = Mock(
-            return_value=1000,
-        )
+        self.store.get_monthly_active_count = Mock(return_value=1000)

-        self.store.user_last_seen_monthly_active = Mock(
-            return_value=1000,
-        )
+        self.store.user_last_seen_monthly_active = Mock(return_value=1000)

         # Call the function multiple times to ensure we only send the notice once
         yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
@@ -193,12 +187,12 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase):
         # Now lets get the last load of messages in the service notice room and
         # check that there is only one server notice
         room_id = yield self.server_notices_manager.get_notice_room_for_user(
-            self.user_id,
+            self.user_id
         )

         token = yield self.event_source.get_current_token()
         events, _ = yield self.store.get_recent_events_for_room(
-            room_id, limit=100, end_token=token.room_key,
+            room_id, limit=100, end_token=token.room_key
         )

         count = 0
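The hunks above only reflow the `Mock(return_value=defer.succeed(...))` stubs, but the pattern itself is worth spelling out: the store lookup is replaced by a mock that returns a canned event map, so the notice code under test sees exactly the state the test wants. The same pattern without Twisted (the event type string here is illustrative, not the real `ServerNoticeMsgType` constant):

```python
from unittest.mock import Mock

# Stub an event and a store lookup, as the tests above do. The real tests
# wrap the value in defer.succeed(...) only so callers can `yield` it.
mock_event = Mock(type="m.room.message", content={"msgtype": "m.example.notice"})
get_events = Mock(return_value={"123": mock_event})

events = get_events()
assert events["123"].content["msgtype"] == "m.example.notice"
get_events.assert_called_once()
```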

View file

@@ -101,13 +101,11 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
         self.hs.config.limit_usage_by_mau = True
         self.hs.config.max_mau_value = 50
         user_id = "@user:server"
-        yield self.store.register(user_id=user_id, token="123", password_hash=None)
         active = yield self.store.user_last_seen_monthly_active(user_id)
         self.assertFalse(active)

-        yield self.store.insert_client_ip(
-            user_id, "access_token", "ip", "user_agent", "device_id"
-        )
         yield self.store.insert_client_ip(
             user_id, "access_token", "ip", "user_agent", "device_id"
         )

View file

@@ -12,6 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from mock import Mock
+
+from twisted.internet import defer

 from tests.unittest import HomeserverTestCase
@@ -23,7 +26,8 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
         hs = self.setup_test_homeserver()
         self.store = hs.get_datastore()
+        hs.config.limit_usage_by_mau = True
+        hs.config.max_mau_value = 50

         # Advance the clock a bit
         reactor.advance(FORTY_DAYS)
@@ -73,7 +77,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
         active_count = self.store.get_monthly_active_count()
         self.assertEquals(self.get_success(active_count), user_num)

-        # Test that regalar users are removed from the db
+        # Test that regular users are removed from the db
         ru_count = 2
         self.store.upsert_monthly_active_user("@ru1:server")
         self.store.upsert_monthly_active_user("@ru2:server")
@@ -139,3 +143,74 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
         count = self.store.get_monthly_active_count()
         self.assertEquals(self.get_success(count), 0)
+
+    def test_populate_monthly_users_is_guest(self):
+        # Test that guest users are not added to mau list
+        user_id = "user_id"
+        self.store.register(
+            user_id=user_id, token="123", password_hash=None, make_guest=True
+        )
+        self.store.upsert_monthly_active_user = Mock()
+        self.store.populate_monthly_active_users(user_id)
+        self.pump()
+        self.store.upsert_monthly_active_user.assert_not_called()
+
+    def test_populate_monthly_users_should_update(self):
+        self.store.upsert_monthly_active_user = Mock()
+
+        self.store.is_trial_user = Mock(
+            return_value=defer.succeed(False)
+        )
+
+        self.store.user_last_seen_monthly_active = Mock(
+            return_value=defer.succeed(None)
+        )
+        self.store.populate_monthly_active_users('user_id')
+        self.pump()
+        self.store.upsert_monthly_active_user.assert_called_once()
+
+    def test_populate_monthly_users_should_not_update(self):
+        self.store.upsert_monthly_active_user = Mock()
+        self.store.is_trial_user = Mock(
+            return_value=defer.succeed(False)
+        )
+        self.store.user_last_seen_monthly_active = Mock(
+            return_value=defer.succeed(
+                self.hs.get_clock().time_msec()
+            )
+        )
+        self.store.populate_monthly_active_users('user_id')
+        self.pump()
+        self.store.upsert_monthly_active_user.assert_not_called()
+
+    def test_get_reserved_real_user_account(self):
+        # Test no reserved users, or reserved threepids
+        count = self.store.get_registered_reserved_users_count()
+        self.assertEquals(self.get_success(count), 0)
+        # Test reserved users but no registered users
+
+        user1 = '@user1:example.com'
+        user2 = '@user2:example.com'
+        user1_email = 'user1@example.com'
+        user2_email = 'user2@example.com'
+        threepids = [
+            {'medium': 'email', 'address': user1_email},
+            {'medium': 'email', 'address': user2_email},
+        ]
+        self.hs.config.mau_limits_reserved_threepids = threepids
+        self.store.initialise_reserved_users(threepids)
+        self.pump()
+        count = self.store.get_registered_reserved_users_count()
+        self.assertEquals(self.get_success(count), 0)
+
+        # Test reserved registered users
+        self.store.register(user_id=user1, token="123", password_hash=None)
+        self.store.register(user_id=user2, token="456", password_hash=None)
+        self.pump()
+
+        now = int(self.hs.get_clock().time_msec())
+        self.store.user_add_threepid(user1, "email", user1_email, now, now)
+        self.store.user_add_threepid(user2, "email", user2_email, now, now)
+        count = self.store.get_registered_reserved_users_count()
+        self.assertEquals(self.get_success(count), len(threepids))
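The new `test_get_reserved_real_user_account` above backs the `synapse_admin_mau:registered_reserved_users` metric from the changelog: a reserved threepid only counts once a real account has registered and bound that address. A toy model of that counting logic (the function and data shapes here are illustrative, not the store's actual schema):

```python
# Count "registered reserved users": reserved threepids that have actually
# been bound to a registered account.
def count_registered_reserved(reserved_threepids, user_threepids):
    reserved = {(t["medium"], t["address"]) for t in reserved_threepids}
    bound = {(medium, address) for (_user, medium, address) in user_threepids}
    return len(reserved & bound)

threepids = [
    {"medium": "email", "address": "user1@example.com"},
    {"medium": "email", "address": "user2@example.com"},
]

# Reserved addresses but no registered users yet: the count stays at zero.
assert count_registered_reserved(threepids, []) == 0

# Both reserved addresses bound to registered accounts: both count.
bound = [
    ("@user1:example.com", "email", "user1@example.com"),
    ("@user2:example.com", "email", "user2@example.com"),
]
assert count_registered_reserved(threepids, bound) == 2
```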

View file

@@ -185,8 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         # test _get_some_state_from_cache correctly filters out members with types=[]
         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
-            self.store._state_group_cache,
-            group, [], filtered_types=[EventTypes.Member]
+            self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member]
         )

         self.assertEqual(is_all, True)
@@ -200,19 +199,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
             self.store._state_group_members_cache,
-            group, [], filtered_types=[EventTypes.Member]
+            group,
+            [],
+            filtered_types=[EventTypes.Member],
         )

         self.assertEqual(is_all, True)
-        self.assertDictEqual(
-            {},
-            state_dict,
-        )
+        self.assertDictEqual({}, state_dict)

         # test _get_some_state_from_cache correctly filters in members with wildcard types
         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
             self.store._state_group_cache,
-            group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+            group,
+            [(EventTypes.Member, None)],
+            filtered_types=[EventTypes.Member],
         )

         self.assertEqual(is_all, True)
@@ -226,7 +226,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
             self.store._state_group_members_cache,
-            group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+            group,
+            [(EventTypes.Member, None)],
+            filtered_types=[EventTypes.Member],
         )

         self.assertEqual(is_all, True)
@@ -264,18 +266,15 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )

         self.assertEqual(is_all, True)
-        self.assertDictEqual(
-            {
-                (e5.type, e5.state_key): e5.event_id,
-            },
-            state_dict,
-        )
+        self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)

         # test _get_some_state_from_cache correctly filters in members with specific types
         # and no filtered_types
         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
             self.store._state_group_members_cache,
-            group, [(EventTypes.Member, e5.state_key)], filtered_types=None
+            group,
+            [(EventTypes.Member, e5.state_key)],
+            filtered_types=None,
         )

         self.assertEqual(is_all, True)
@@ -305,9 +304,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             key=group,
             value=state_dict_ids,
             # list fetched keys so it knows it's partial
-            fetched_keys=(
-                (e1.type, e1.state_key),
-            ),
+            fetched_keys=((e1.type, e1.state_key),),
         )

         (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
@@ -315,20 +312,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )

         self.assertEqual(is_all, False)
-        self.assertEqual(
-            known_absent,
-            set(
-                [
-                    (e1.type, e1.state_key),
-                ]
-            ),
-        )
-        self.assertDictEqual(
-            state_dict_ids,
-            {
-                (e1.type, e1.state_key): e1.event_id,
-            },
-        )
+        self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
+        self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})

         ############################################
         # test that things work with a partial cache
@@ -336,8 +321,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         # test _get_some_state_from_cache correctly filters out members with types=[]
         room_id = self.room.to_string()
         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
-            self.store._state_group_cache,
-            group, [], filtered_types=[EventTypes.Member]
+            self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member]
         )

         self.assertEqual(is_all, False)
@@ -346,7 +330,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
         room_id = self.room.to_string()
         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
             self.store._state_group_members_cache,
-            group, [], filtered_types=[EventTypes.Member]
+            group,
+            [],
+            filtered_types=[EventTypes.Member],
         )

         self.assertEqual(is_all, True)
@@ -355,20 +341,19 @@ class StateStoreTestCase(tests.unittest.TestCase):
         # test _get_some_state_from_cache correctly filters in members wildcard types
         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
             self.store._state_group_cache,
-            group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+            group,
+            [(EventTypes.Member, None)],
+            filtered_types=[EventTypes.Member],
         )

         self.assertEqual(is_all, False)
-        self.assertDictEqual(
-            {
-                (e1.type, e1.state_key): e1.event_id,
-            },
-            state_dict,
-        )
+        self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)

         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
             self.store._state_group_members_cache,
-            group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+            group,
+            [(EventTypes.Member, None)],
+            filtered_types=[EventTypes.Member],
         )

         self.assertEqual(is_all, True)
@@ -389,12 +374,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )

         self.assertEqual(is_all, False)
-        self.assertDictEqual(
-            {
-                (e1.type, e1.state_key): e1.event_id,
-            },
-            state_dict,
-        )
+        self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)

         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
             self.store._state_group_members_cache,
@@ -404,18 +384,15 @@
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual( self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
{
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
)
# test _get_some_state_from_cache correctly filters in members with specific types # test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types # and no filtered_types
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_cache, self.store._state_group_cache,
group, [(EventTypes.Member, e5.state_key)], filtered_types=None group,
[(EventTypes.Member, e5.state_key)],
filtered_types=None,
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
@ -423,13 +400,10 @@ class StateStoreTestCase(tests.unittest.TestCase):
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, [(EventTypes.Member, e5.state_key)], filtered_types=None group,
[(EventTypes.Member, e5.state_key)],
filtered_types=None,
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual( self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
{
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
)


```diff
@@ -185,20 +185,20 @@ class TestMauLimit(unittest.TestCase):
         self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)

     def create_user(self, localpart):
-        request_data = json.dumps({
-            "username": localpart,
-            "password": "monkey",
-            "auth": {"type": LoginType.DUMMY},
-        })
+        request_data = json.dumps(
+            {
+                "username": localpart,
+                "password": "monkey",
+                "auth": {"type": LoginType.DUMMY},
+            }
+        )

-        request, channel = make_request(b"POST", b"/register", request_data)
+        request, channel = make_request("POST", "/register", request_data)
         render(request, self.resource, self.reactor)

-        if channel.result["code"] != b"200":
+        if channel.code != 200:
             raise HttpResponseException(
-                int(channel.result["code"]),
-                channel.result["reason"],
-                channel.result["body"],
+                channel.code, channel.result["reason"], channel.result["body"]
             ).to_synapse_error()

         access_token = channel.json_body["access_token"]
@@ -206,12 +206,12 @@ class TestMauLimit(unittest.TestCase):
         return access_token

     def do_sync_for_user(self, token):
-        request, channel = make_request(b"GET", b"/sync", access_token=token)
+        request, channel = make_request(
+            "GET", "/sync", access_token=token.encode('ascii')
+        )
         render(request, self.resource, self.reactor)

-        if channel.result["code"] != b"200":
+        if channel.code != 200:
             raise HttpResponseException(
-                int(channel.result["code"]),
-                channel.result["reason"],
-                channel.result["body"],
+                channel.code, channel.result["reason"], channel.result["body"]
             ).to_synapse_error()
```
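The status-code checks above move from comparing a bytes literal (`b"200"`) to comparing an int (`200`). This matters for the Python 3 work this release covers: in Python 3, `bytes` never compare equal to `str` or `int`, so the old check could misfire silently. A minimal illustration of that language behaviour:

```python
# In Python 3, bytes are a distinct type: they never compare equal to str or
# int, so a check like `channel.result["code"] != b"200"` breaks as soon as
# the status code is stored as a str or an int.
status_as_bytes = b"200"

assert status_as_bytes != "200"     # bytes vs str: always unequal in Python 3
assert status_as_bytes != 200       # bytes vs int: always unequal
assert int(status_as_bytes) == 200  # an explicit conversion is required
```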

tests/test_metrics.py (new file)
```python
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from synapse.metrics import InFlightGauge

from tests import unittest


class TestMauLimit(unittest.TestCase):
    def test_basic(self):
        gauge = InFlightGauge(
            "test1", "",
            labels=["test_label"],
            sub_metrics=["foo", "bar"],
        )

        def handle1(metrics):
            metrics.foo += 2
            metrics.bar = max(metrics.bar, 5)

        def handle2(metrics):
            metrics.foo += 3
            metrics.bar = max(metrics.bar, 7)

        gauge.register(("key1",), handle1)

        self.assert_dict({
            "test1_total": {("key1",): 1},
            "test1_foo": {("key1",): 2},
            "test1_bar": {("key1",): 5},
        }, self.get_metrics_from_gauge(gauge))

        gauge.unregister(("key1",), handle1)

        self.assert_dict({
            "test1_total": {("key1",): 0},
            "test1_foo": {("key1",): 0},
            "test1_bar": {("key1",): 0},
        }, self.get_metrics_from_gauge(gauge))

        gauge.register(("key1",), handle1)
        gauge.register(("key2",), handle2)

        self.assert_dict({
            "test1_total": {("key1",): 1, ("key2",): 1},
            "test1_foo": {("key1",): 2, ("key2",): 3},
            "test1_bar": {("key1",): 5, ("key2",): 7},
        }, self.get_metrics_from_gauge(gauge))

        gauge.unregister(("key2",), handle2)
        gauge.register(("key1",), handle2)

        self.assert_dict({
            "test1_total": {("key1",): 2, ("key2",): 0},
            "test1_foo": {("key1",): 5, ("key2",): 0},
            "test1_bar": {("key1",): 7, ("key2",): 0},
        }, self.get_metrics_from_gauge(gauge))

    def get_metrics_from_gauge(self, gauge):
        results = {}

        for r in gauge.collect():
            results[r.name] = {
                tuple(labels[x] for x in gauge.labels): value
                for _, labels, value in r.samples
            }

        return results
```
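The new test exercises `InFlightGauge`'s callback model: each in-flight operation registers a callback under a label key, and collection invokes every live callback against a fresh counters object. As a rough self-contained sketch of that idea (a hypothetical simplified class, not Synapse's actual implementation):

```python
# Hypothetical, simplified in-flight gauge: callers register a callback per
# label key while an operation is in flight, and collect() asks every live
# callback to update a fresh counters object. This is a sketch of the idea
# only; synapse.metrics.InFlightGauge is the real implementation.
class SimpleInFlightGauge:
    def __init__(self, name, sub_metrics):
        self.name = name
        self.sub_metrics = sub_metrics
        self._entries = {}  # key -> set of registered callbacks

    def register(self, key, callback):
        self._entries.setdefault(key, set()).add(callback)

    def unregister(self, key, callback):
        self._entries.get(key, set()).discard(callback)

    def collect(self):
        class Metrics:
            pass

        # One result series per sub-metric, plus a _total count of callbacks.
        results = {"%s_total" % self.name: {}}
        for sub in self.sub_metrics:
            results["%s_%s" % (self.name, sub)] = {}

        for key, callbacks in self._entries.items():
            metrics = Metrics()
            for sub in self.sub_metrics:
                setattr(metrics, sub, 0)
            for cb in callbacks:
                cb(metrics)
            results["%s_total" % self.name][key] = len(callbacks)
            for sub in self.sub_metrics:
                results["%s_%s" % (self.name, sub)][key] = getattr(metrics, sub)
        return results


gauge = SimpleInFlightGauge("test1", ["foo", "bar"])

def handle1(metrics):
    metrics.foo += 2
    metrics.bar = max(metrics.bar, 5)

gauge.register(("key1",), handle1)
print(gauge.collect())
```

Mirroring the first assertion in the test above, this reports one registered callback for `("key1",)` with `foo` at 2 and `bar` at 5.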


```diff
@@ -180,7 +180,7 @@ class StateTestCase(unittest.TestCase):
         graph = Graph(
             nodes={
                 "START": DictObj(
-                    type=EventTypes.Create, state_key="", content={}, depth=1,
+                    type=EventTypes.Create, state_key="", content={}, depth=1
                 ),
                 "A": DictObj(type=EventTypes.Message, depth=2),
                 "B": DictObj(type=EventTypes.Message, depth=3),
```


```diff
@@ -100,8 +100,13 @@ class TestHomeServer(HomeServer):

 @defer.inlineCallbacks
 def setup_test_homeserver(
-    cleanup_func, name="test", datastore=None, config=None, reactor=None,
-    homeserverToUse=TestHomeServer, **kargs
+    cleanup_func,
+    name="test",
+    datastore=None,
+    config=None,
+    reactor=None,
+    homeserverToUse=TestHomeServer,
+    **kargs
 ):
     """
     Setup a homeserver suitable for running tests against. Keyword arguments
@@ -147,6 +152,7 @@ def setup_test_homeserver(
         config.hs_disabled_message = ""
         config.hs_disabled_limit_type = ""
         config.max_mau_value = 50
+        config.mau_trial_days = 0
         config.mau_limits_reserved_threepids = []
         config.admin_contact = None
         config.rc_messages_per_second = 10000
@@ -322,8 +328,7 @@ class MockHttpResource(HttpServer):
     @patch('twisted.web.http.Request')
     @defer.inlineCallbacks
     def trigger(
-        self, http_method, path, content, mock_request,
-        federation_auth_origin=None,
+        self, http_method, path, content, mock_request, federation_auth_origin=None
     ):
         """ Fire an HTTP event.
@@ -356,7 +361,7 @@ class MockHttpResource(HttpServer):
         headers = {}
         if federation_auth_origin is not None:
             headers[b"Authorization"] = [
-                b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin, )
+                b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
             ]

         mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
@@ -576,16 +581,16 @@ def create_room(hs, room_id, creator_id):
     event_builder_factory = hs.get_event_builder_factory()
     event_creation_handler = hs.get_event_creation_handler()

-    builder = event_builder_factory.new({
-        "type": EventTypes.Create,
-        "state_key": "",
-        "sender": creator_id,
-        "room_id": room_id,
-        "content": {},
-    })
-
-    event, context = yield event_creation_handler.create_new_client_event(
-        builder
-    )
+    builder = event_builder_factory.new(
+        {
+            "type": EventTypes.Create,
+            "state_key": "",
+            "sender": creator_id,
+            "room_id": room_id,
+            "content": {},
+        }
+    )
+    event, context = yield event_creation_handler.create_new_client_event(builder)

     yield store.persist_event(event, context)
```

tox.ini

```diff
@@ -64,49 +64,11 @@ setenv =
     {[base]setenv}
     SYNAPSE_POSTGRES = 1

+[testenv:py35]
+usedevelop=true
+
 [testenv:py36]
 usedevelop=true
-commands =
-    /usr/bin/find "{toxinidir}" -name '*.pyc' -delete
-    coverage run {env:COVERAGE_OPTS:} --source="{toxinidir}/synapse" \
-        "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests/config \
-        tests/api/test_filtering.py \
-        tests/api/test_ratelimiting.py \
-        tests/appservice \
-        tests/crypto \
-        tests/events \
-        tests/handlers/test_appservice.py \
-        tests/handlers/test_auth.py \
-        tests/handlers/test_device.py \
-        tests/handlers/test_directory.py \
-        tests/handlers/test_e2e_keys.py \
-        tests/handlers/test_presence.py \
-        tests/handlers/test_profile.py \
-        tests/handlers/test_register.py \
-        tests/replication/slave/storage/test_account_data.py \
-        tests/replication/slave/storage/test_receipts.py \
-        tests/storage/test_appservice.py \
-        tests/storage/test_background_update.py \
-        tests/storage/test_base.py \
-        tests/storage/test__base.py \
-        tests/storage/test_client_ips.py \
-        tests/storage/test_devices.py \
-        tests/storage/test_end_to_end_keys.py \
-        tests/storage/test_event_push_actions.py \
-        tests/storage/test_keys.py \
-        tests/storage/test_presence.py \
-        tests/storage/test_profile.py \
-        tests/storage/test_registration.py \
-        tests/storage/test_room.py \
-        tests/storage/test_user_directory.py \
-        tests/test_distributor.py \
-        tests/test_dns.py \
-        tests/test_preview.py \
-        tests/test_test_utils.py \
-        tests/test_types.py \
-        tests/util} \
-        {env:TOXSUFFIX:}
-    {env:DUMP_COVERAGE_COMMAND:coverage report -m}

 [testenv:packaging]
 deps =
```