mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-17 07:21:37 +01:00
Merge remote-tracking branch 'origin/develop' into dbkr/notifications_api
This commit is contained in:
commit
b4ecf0b886
208 changed files with 9809 additions and 3566 deletions
294
CHANGES.rst
294
CHANGES.rst
|
@ -1,3 +1,297 @@
|
|||
Changes in synapse v0.17.0 (2016-08-08)
|
||||
=======================================
|
||||
|
||||
This release contains significant security bug fixes regarding authenticating
|
||||
events received over federation. PLEASE UPGRADE.
|
||||
|
||||
This release changes the LDAP configuration format in a backwards incompatible
|
||||
way, see PR #843 for details.
|
||||
|
||||
|
||||
Changes:
|
||||
|
||||
* Add federation /version API (PR #990)
|
||||
* Make psutil dependency optional (PR #992)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix URL preview API to exclude HTML comments in description (PR #988)
|
||||
* Fix error handling of remote joins (PR #991)
|
||||
|
||||
|
||||
Changes in synapse v0.17.0-rc4 (2016-08-05)
|
||||
===========================================
|
||||
|
||||
Changes:
|
||||
|
||||
* Change the way we summarize URLs when previewing (PR #973)
|
||||
* Add new ``/state_ids/`` federation API (PR #979)
|
||||
* Speed up processing of ``/state/`` response (PR #986)
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix event persistence when event has already been partially persisted
|
||||
(PR #975, #983, #985)
|
||||
* Fix port script to also copy across backfilled events (PR #982)
|
||||
|
||||
|
||||
Changes in synapse v0.17.0-rc3 (2016-08-02)
|
||||
===========================================
|
||||
|
||||
Changes:
|
||||
|
||||
* Forbid non-ASes from registering users whose names begin with '_' (PR #958)
|
||||
* Add some basic admin API docs (PR #963)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Send the correct host header when fetching keys (PR #941)
|
||||
* Fix joining a room that has missing auth events (PR #964)
|
||||
* Fix various push bugs (PR #966, #970)
|
||||
* Fix adding emails on registration (PR #968)
|
||||
|
||||
|
||||
Changes in synapse v0.17.0-rc2 (2016-08-02)
|
||||
===========================================
|
||||
|
||||
(This release did not include the changes advertised and was identical to RC1)
|
||||
|
||||
|
||||
Changes in synapse v0.17.0-rc1 (2016-07-28)
|
||||
===========================================
|
||||
|
||||
This release changes the LDAP configuration format in a backwards incompatible
|
||||
way, see PR #843 for details.
|
||||
|
||||
|
||||
Features:
|
||||
|
||||
* Add purge_media_cache admin API (PR #902)
|
||||
* Add deactivate account admin API (PR #903)
|
||||
* Add optional pepper to password hashing (PR #907, #910 by KentShikama)
|
||||
* Add an admin option to shared secret registration (breaks backwards compat)
|
||||
(PR #909)
|
||||
* Add purge local room history API (PR #911, #923, #924)
|
||||
* Add requestToken endpoints (PR #915)
|
||||
* Add an /account/deactivate endpoint (PR #921)
|
||||
* Add filter param to /messages. Add 'contains_url' to filter. (PR #922)
|
||||
* Add device_id support to /login (PR #929)
|
||||
* Add device_id support to /v2/register flow. (PR #937, #942)
|
||||
* Add GET /devices endpoint (PR #939, #944)
|
||||
* Add GET /device/{deviceId} (PR #943)
|
||||
* Add update and delete APIs for devices (PR #949)
|
||||
|
||||
|
||||
Changes:
|
||||
|
||||
* Rewrite LDAP Authentication against ldap3 (PR #843 by mweinelt)
|
||||
* Linearize some federation endpoints based on (origin, room_id) (PR #879)
|
||||
* Remove the legacy v0 content upload API. (PR #888)
|
||||
* Use similar naming we use in email notifs for push (PR #894)
|
||||
* Optionally include password hash in createUser endpoint (PR #905 by
|
||||
KentShikama)
|
||||
* Use a query that postgresql optimises better for get_events_around (PR #906)
|
||||
* Fall back to 'username' if 'user' is not given for appservice registration.
|
||||
(PR #927 by Half-Shot)
|
||||
* Add metrics for psutil derived memory usage (PR #936)
|
||||
* Record device_id in client_ips (PR #938)
|
||||
* Send the correct host header when fetching keys (PR #941)
|
||||
* Log the hostname the reCAPTCHA was completed on (PR #946)
|
||||
* Make the device id on e2e key upload optional (PR #956)
|
||||
* Add r0.2.0 to the "supported versions" list (PR #960)
|
||||
* Don't include name of room for invites in push (PR #961)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix substitution failure in mail template (PR #887)
|
||||
* Put most recent 20 messages in email notif (PR #892)
|
||||
* Ensure that the guest user is in the database when upgrading accounts
|
||||
(PR #914)
|
||||
* Fix various edge cases in auth handling (PR #919)
|
||||
* Fix 500 ISE when sending alias event without a state_key (PR #925)
|
||||
* Fix bug where we stored rejections in the state_group, persist all
|
||||
rejections (PR #948)
|
||||
* Fix lack of check of if the user is banned when handling 3pid invites
|
||||
(PR #952)
|
||||
* Fix a couple of bugs in the transaction and keyring code (PR #954, #955)
|
||||
|
||||
|
||||
|
||||
Changes in synapse v0.16.1-r1 (2016-07-08)
|
||||
==========================================
|
||||
|
||||
THIS IS A CRITICAL SECURITY UPDATE.
|
||||
|
||||
This fixes a bug which allowed users' accounts to be accessed by unauthorised
|
||||
users.
|
||||
|
||||
Changes in synapse v0.16.1 (2016-06-20)
|
||||
=======================================
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix assorted bugs in ``/preview_url`` (PR #872)
|
||||
* Fix TypeError when setting unicode passwords (PR #873)
|
||||
|
||||
|
||||
Performance improvements:
|
||||
|
||||
* Turn ``use_frozen_events`` off by default (PR #877)
|
||||
* Disable responding with canonical json for federation (PR #878)
|
||||
|
||||
|
||||
Changes in synapse v0.16.1-rc1 (2016-06-15)
|
||||
===========================================
|
||||
|
||||
Features: None
|
||||
|
||||
Changes:
|
||||
|
||||
* Log requester for ``/publicRoom`` endpoints when possible (PR #856)
|
||||
* 502 on ``/thumbnail`` when can't connect to remote server (PR #862)
|
||||
* Linearize fetching of gaps on incoming events (PR #871)
|
||||
|
||||
|
||||
Bugs fixes:
|
||||
|
||||
* Fix bug where rooms where marked as published by default (PR #857)
|
||||
* Fix bug where joining room with an event with invalid sender (PR #868)
|
||||
* Fix bug where backfilled events were sent down sync streams (PR #869)
|
||||
* Fix bug where outgoing connections could wedge indefinitely, causing push
|
||||
notifications to be unreliable (PR #870)
|
||||
|
||||
|
||||
Performance improvements:
|
||||
|
||||
* Improve ``/publicRooms`` performance(PR #859)
|
||||
|
||||
|
||||
Changes in synapse v0.16.0 (2016-06-09)
|
||||
=======================================
|
||||
|
||||
NB: As of v0.14 all AS config files must have an ID field.
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Don't make rooms published by default (PR #857)
|
||||
|
||||
Changes in synapse v0.16.0-rc2 (2016-06-08)
|
||||
===========================================
|
||||
|
||||
Features:
|
||||
|
||||
* Add configuration option for tuning GC via ``gc.set_threshold`` (PR #849)
|
||||
|
||||
Changes:
|
||||
|
||||
* Record metrics about GC (PR #771, #847, #852)
|
||||
* Add metric counter for number of persisted events (PR #841)
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix 'From' header in email notifications (PR #843)
|
||||
* Fix presence where timeouts were not being fired for the first 8h after
|
||||
restarts (PR #842)
|
||||
* Fix bug where synapse sent malformed transactions to AS's when retrying
|
||||
transactions (Commits 310197b, 8437906)
|
||||
|
||||
Performance improvements:
|
||||
|
||||
* Remove event fetching from DB threads (PR #835)
|
||||
* Change the way we cache events (PR #836)
|
||||
* Add events to cache when we persist them (PR #840)
|
||||
|
||||
|
||||
Changes in synapse v0.16.0-rc1 (2016-06-03)
|
||||
===========================================
|
||||
|
||||
Version 0.15 was not released. See v0.15.0-rc1 below for additional changes.
|
||||
|
||||
Features:
|
||||
|
||||
* Add email notifications for missed messages (PR #759, #786, #799, #810, #815,
|
||||
#821)
|
||||
* Add a ``url_preview_ip_range_whitelist`` config param (PR #760)
|
||||
* Add /report endpoint (PR #762)
|
||||
* Add basic ignore user API (PR #763)
|
||||
* Add an openidish mechanism for proving that you own a given user_id (PR #765)
|
||||
* Allow clients to specify a server_name to avoid 'No known servers' (PR #794)
|
||||
* Add secondary_directory_servers option to fetch room list from other servers
|
||||
(PR #808, #813)
|
||||
|
||||
Changes:
|
||||
|
||||
* Report per request metrics for all of the things using request_handler (PR
|
||||
#756)
|
||||
* Correctly handle ``NULL`` password hashes from the database (PR #775)
|
||||
* Allow receipts for events we haven't seen in the db (PR #784)
|
||||
* Make synctl read a cache factor from config file (PR #785)
|
||||
* Increment badge count per missed convo, not per msg (PR #793)
|
||||
* Special case m.room.third_party_invite event auth to match invites (PR #814)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix typo in event_auth servlet path (PR #757)
|
||||
* Fix password reset (PR #758)
|
||||
|
||||
|
||||
Performance improvements:
|
||||
|
||||
* Reduce database inserts when sending transactions (PR #767)
|
||||
* Queue events by room for persistence (PR #768)
|
||||
* Add cache to ``get_user_by_id`` (PR #772)
|
||||
* Add and use ``get_domain_from_id`` (PR #773)
|
||||
* Use tree cache for ``get_linearized_receipts_for_room`` (PR #779)
|
||||
* Remove unused indices (PR #782)
|
||||
* Add caches to ``bulk_get_push_rules*`` (PR #804)
|
||||
* Cache ``get_event_reference_hashes`` (PR #806)
|
||||
* Add ``get_users_with_read_receipts_in_room`` cache (PR #809)
|
||||
* Use state to calculate ``get_users_in_room`` (PR #811)
|
||||
* Load push rules in storage layer so that they get cached (PR #825)
|
||||
* Make ``get_joined_hosts_for_room`` use get_users_in_room (PR #828)
|
||||
* Poke notifier on next reactor tick (PR #829)
|
||||
* Change CacheMetrics to be quicker (PR #830)
|
||||
|
||||
|
||||
Changes in synapse v0.15.0-rc1 (2016-04-26)
|
||||
===========================================
|
||||
|
||||
Features:
|
||||
|
||||
* Add login support for Javascript Web Tokens, thanks to Niklas Riekenbrauck
|
||||
(PR #671,#687)
|
||||
* Add URL previewing support (PR #688)
|
||||
* Add login support for LDAP, thanks to Christoph Witzany (PR #701)
|
||||
* Add GET endpoint for pushers (PR #716)
|
||||
|
||||
Changes:
|
||||
|
||||
* Never notify for member events (PR #667)
|
||||
* Deduplicate identical ``/sync`` requests (PR #668)
|
||||
* Require user to have left room to forget room (PR #673)
|
||||
* Use DNS cache if within TTL (PR #677)
|
||||
* Let users see their own leave events (PR #699)
|
||||
* Deduplicate membership changes (PR #700)
|
||||
* Increase performance of pusher code (PR #705)
|
||||
* Respond with error status 504 if failed to talk to remote server (PR #731)
|
||||
* Increase search performance on postgres (PR #745)
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix bug where disabling all notifications still resulted in push (PR #678)
|
||||
* Fix bug where users couldn't reject remote invites if remote refused (PR #691)
|
||||
* Fix bug where synapse attempted to backfill from itself (PR #693)
|
||||
* Fix bug where profile information was not correctly added when joining remote
|
||||
rooms (PR #703)
|
||||
* Fix bug where register API required incorrect key name for AS registration
|
||||
(PR #727)
|
||||
|
||||
|
||||
Changes in synapse v0.14.0 (2016-03-30)
|
||||
=======================================
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ recursive-include docs *
|
|||
recursive-include res *
|
||||
recursive-include scripts *
|
||||
recursive-include scripts-dev *
|
||||
recursive-include synapse *.pyi
|
||||
recursive-include tests *.py
|
||||
|
||||
recursive-include synapse/static *.css
|
||||
|
@ -23,5 +24,7 @@ recursive-include synapse/static *.js
|
|||
|
||||
exclude jenkins.sh
|
||||
exclude jenkins*.sh
|
||||
exclude jenkins*
|
||||
recursive-exclude jenkins *.sh
|
||||
|
||||
prune demo/etc
|
||||
|
|
21
README.rst
21
README.rst
|
@ -11,8 +11,8 @@ VoIP. The basics you need to know to get up and running are:
|
|||
like ``#matrix:matrix.org`` or ``#test:localhost:8448``.
|
||||
|
||||
- Matrix user IDs look like ``@matthew:matrix.org`` (although in the future
|
||||
you will normally refer to yourself and others using a 3PID: email
|
||||
address, phone number, etc rather than manipulating Matrix user IDs)
|
||||
you will normally refer to yourself and others using a third party identifier
|
||||
(3PID): email address, phone number, etc rather than manipulating Matrix user IDs)
|
||||
|
||||
The overall architecture is::
|
||||
|
||||
|
@ -58,12 +58,13 @@ the spec in the context of a codebase and let you run your own homeserver and
|
|||
generally help bootstrap the ecosystem.
|
||||
|
||||
In Matrix, every user runs one or more Matrix clients, which connect through to
|
||||
a Matrix homeserver which stores all their personal chat history and user
|
||||
account information - much as a mail client connects through to an IMAP/SMTP
|
||||
server. Just like email, you can either run your own Matrix homeserver and
|
||||
control and own your own communications and history or use one hosted by
|
||||
someone else (e.g. matrix.org) - there is no single point of control or
|
||||
mandatory service provider in Matrix, unlike WhatsApp, Facebook, Hangouts, etc.
|
||||
a Matrix homeserver. The homeserver stores all their personal chat history and
|
||||
user account information - much as a mail client connects through to an
|
||||
IMAP/SMTP server. Just like email, you can either run your own Matrix
|
||||
homeserver and control and own your own communications and history or use one
|
||||
hosted by someone else (e.g. matrix.org) - there is no single point of control
|
||||
or mandatory service provider in Matrix, unlike WhatsApp, Facebook, Hangouts,
|
||||
etc.
|
||||
|
||||
Synapse ships with two basic demo Matrix clients: webclient (a basic group chat
|
||||
web client demo implemented in AngularJS) and cmdclient (a basic Python
|
||||
|
@ -444,7 +445,7 @@ You have two choices here, which will influence the form of your Matrix user
|
|||
IDs:
|
||||
|
||||
1) Use the machine's own hostname as available on public DNS in the form of
|
||||
its A or AAAA records. This is easier to set up initially, perhaps for
|
||||
its A records. This is easier to set up initially, perhaps for
|
||||
testing, but lacks the flexibility of SRV.
|
||||
|
||||
2) Set up a SRV record for your domain name. This requires you create a SRV
|
||||
|
@ -617,7 +618,7 @@ Building internal API documentation::
|
|||
|
||||
|
||||
|
||||
Halp!! Synapse eats all my RAM!
|
||||
Help!! Synapse eats all my RAM!
|
||||
===============================
|
||||
|
||||
Synapse's architecture is quite RAM hungry currently - we deliberately
|
||||
|
|
|
@ -27,7 +27,7 @@ running:
|
|||
# Pull the latest version of the master branch.
|
||||
git pull
|
||||
# Update the versions of synapse's python dependencies.
|
||||
python synapse/python_dependencies.py | xargs -n1 pip install
|
||||
python synapse/python_dependencies.py | xargs -n1 pip install --upgrade
|
||||
|
||||
|
||||
Upgrading to v0.15.0
|
||||
|
|
|
@ -9,6 +9,7 @@ Description=Synapse Matrix homeserver
|
|||
Type=simple
|
||||
User=synapse
|
||||
Group=synapse
|
||||
EnvironmentFile=-/etc/sysconfig/synapse
|
||||
WorkingDirectory=/var/lib/synapse
|
||||
ExecStart=/usr/bin/python2.7 -m synapse.app.homeserver --config-path=/etc/synapse/homeserver.yaml --log-config=/etc/synapse/log_config.yaml
|
||||
|
||||
|
|
12
docs/admin_api/README.rst
Normal file
12
docs/admin_api/README.rst
Normal file
|
@ -0,0 +1,12 @@
|
|||
Admin APIs
|
||||
==========
|
||||
|
||||
This directory includes documentation for the various synapse specific admin
|
||||
APIs available.
|
||||
|
||||
Only users that are server admins can use these APIs. A user can be marked as a
|
||||
server admin by updating the database directly, e.g.:
|
||||
|
||||
``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'``
|
||||
|
||||
Restarting may be required for the changes to register.
|
15
docs/admin_api/purge_history_api.rst
Normal file
15
docs/admin_api/purge_history_api.rst
Normal file
|
@ -0,0 +1,15 @@
|
|||
Purge History API
|
||||
=================
|
||||
|
||||
The purge history API allows server admins to purge historic events from their
|
||||
database, reclaiming disk space.
|
||||
|
||||
Depending on the amount of history being purged a call to the API may take
|
||||
several minutes or longer. During this period users will not be able to
|
||||
paginate further back in the room from the point being purged from.
|
||||
|
||||
The API is simply:
|
||||
|
||||
``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
|
||||
|
||||
including an ``access_token`` of a server admin.
|
19
docs/admin_api/purge_remote_media.rst
Normal file
19
docs/admin_api/purge_remote_media.rst
Normal file
|
@ -0,0 +1,19 @@
|
|||
Purge Remote Media API
|
||||
======================
|
||||
|
||||
The purge remote media API allows server admins to purge old cached remote
|
||||
media.
|
||||
|
||||
The API is::
|
||||
|
||||
POST /_matrix/client/r0/admin/purge_media_cache
|
||||
|
||||
{
|
||||
"before_ts": <unix_timestamp_in_ms>
|
||||
}
|
||||
|
||||
Which will remove all cached media that was last accessed before
|
||||
``<unix_timestamp_in_ms>``.
|
||||
|
||||
If the user re-requests purged remote media, synapse will re-request the media
|
||||
from the originating server.
|
|
@ -32,5 +32,4 @@ The format of the AS configuration file is as follows:
|
|||
|
||||
See the spec_ for further details on how application services work.
|
||||
|
||||
.. _spec: https://github.com/matrix-org/matrix-doc/blob/master/specification/25_application_service_api.rst#application-service-api
|
||||
|
||||
.. _spec: https://matrix.org/docs/spec/application_service/unstable.html
|
||||
|
|
|
@ -43,7 +43,10 @@ Basically, PEP8
|
|||
together, or want to deliberately extend or preserve vertical/horizontal
|
||||
space)
|
||||
|
||||
Comments should follow the google code style. This is so that we can generate
|
||||
documentation with sphinx (http://sphinxcontrib-napoleon.readthedocs.org/en/latest/)
|
||||
Comments should follow the `google code style <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
|
||||
This is so that we can generate documentation with
|
||||
`sphinx <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
|
||||
`examples <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
|
||||
in the sphinx documentation.
|
||||
|
||||
Code should pass pep8 --max-line-length=100 without any warnings.
|
||||
|
|
|
@ -9,31 +9,35 @@ the Home Server to generate credentials that are valid for use on the TURN
|
|||
server through the use of a secret shared between the Home Server and the
|
||||
TURN server.
|
||||
|
||||
This document described how to install coturn
|
||||
(https://code.google.com/p/coturn/) which also supports the TURN REST API,
|
||||
This document describes how to install coturn
|
||||
(https://github.com/coturn/coturn) which also supports the TURN REST API,
|
||||
and integrate it with synapse.
|
||||
|
||||
coturn Setup
|
||||
============
|
||||
|
||||
You may be able to setup coturn via your package manager, or set it up manually using the usual ``configure, make, make install`` process.
|
||||
|
||||
1. Check out coturn::
|
||||
svn checkout http://coturn.googlecode.com/svn/trunk/ coturn
|
||||
|
||||
git clone https://github.com/coturn/coturn.git coturn
|
||||
cd coturn
|
||||
|
||||
2. Configure it::
|
||||
|
||||
./configure
|
||||
|
||||
You may need to install libevent2: if so, you should do so
|
||||
You may need to install ``libevent2``: if so, you should do so
|
||||
in the way recommended by your operating system.
|
||||
You can ignore warnings about lack of database support: a
|
||||
database is unnecessary for this purpose.
|
||||
|
||||
3. Build and install it::
|
||||
|
||||
make
|
||||
make install
|
||||
|
||||
4. Make a config file in /etc/turnserver.conf. You can customise
|
||||
a config file from turnserver.conf.default. The relevant
|
||||
4. Create or edit the config file in ``/etc/turnserver.conf``. The relevant
|
||||
lines, with example values, are::
|
||||
|
||||
lt-cred-mech
|
||||
|
@ -41,7 +45,7 @@ coturn Setup
|
|||
static-auth-secret=[your secret key here]
|
||||
realm=turn.myserver.org
|
||||
|
||||
See turnserver.conf.default for explanations of the options.
|
||||
See turnserver.conf for explanations of the options.
|
||||
One way to generate the static-auth-secret is with pwgen::
|
||||
|
||||
pwgen -s 64 1
|
||||
|
@ -54,6 +58,7 @@ coturn Setup
|
|||
import your private key and certificate.
|
||||
|
||||
7. Start the turn server::
|
||||
|
||||
bin/turnserver -o
|
||||
|
||||
|
||||
|
|
22
jenkins-dendron-postgres.sh
Executable file
22
jenkins-dendron-postgres.sh
Executable file
|
@ -0,0 +1,22 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -eux
|
||||
|
||||
: ${WORKSPACE:="$(pwd)"}
|
||||
|
||||
export WORKSPACE
|
||||
export PYTHONDONTWRITEBYTECODE=yep
|
||||
export SYNAPSE_CACHE_FACTOR=1
|
||||
|
||||
./jenkins/prepare_synapse.sh
|
||||
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
|
||||
./jenkins/clone.sh dendron https://github.com/matrix-org/dendron.git
|
||||
./dendron/jenkins/build_dendron.sh
|
||||
./sytest/jenkins/prep_sytest_for_postgres.sh
|
||||
|
||||
./sytest/jenkins/install_and_run.sh \
|
||||
--synapse-directory $WORKSPACE \
|
||||
--dendron $WORKSPACE/dendron/bin/dendron \
|
||||
--pusher \
|
||||
--synchrotron \
|
||||
--federation-reader \
|
|
@ -4,60 +4,14 @@ set -eux
|
|||
|
||||
: ${WORKSPACE:="$(pwd)"}
|
||||
|
||||
export WORKSPACE
|
||||
export PYTHONDONTWRITEBYTECODE=yep
|
||||
export SYNAPSE_CACHE_FACTOR=1
|
||||
|
||||
# Output test results as junit xml
|
||||
export TRIAL_FLAGS="--reporter=subunit"
|
||||
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
|
||||
# Write coverage reports to a separate file for each process
|
||||
export COVERAGE_OPTS="-p"
|
||||
export DUMP_COVERAGE_COMMAND="coverage help"
|
||||
./jenkins/prepare_synapse.sh
|
||||
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
|
||||
|
||||
# Output flake8 violations to violations.flake8.log
|
||||
# Don't exit with non-0 status code on Jenkins,
|
||||
# so that the build steps continue and a later step can decided whether to
|
||||
# UNSTABLE or FAILURE this build.
|
||||
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
|
||||
./sytest/jenkins/prep_sytest_for_postgres.sh
|
||||
|
||||
rm .coverage* || echo "No coverage files to remove"
|
||||
|
||||
tox --notest -e py27
|
||||
|
||||
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||
$TOX_BIN/pip install psycopg2
|
||||
$TOX_BIN/pip install lxml
|
||||
|
||||
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
|
||||
|
||||
if [[ ! -e .sytest-base ]]; then
|
||||
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
|
||||
else
|
||||
(cd .sytest-base; git fetch -p)
|
||||
fi
|
||||
|
||||
rm -rf sytest
|
||||
git clone .sytest-base sytest --shared
|
||||
cd sytest
|
||||
|
||||
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
|
||||
|
||||
: ${PORT_BASE:=8000}
|
||||
|
||||
./jenkins/prep_sytest_for_postgres.sh
|
||||
|
||||
echo >&2 "Running sytest with PostgreSQL";
|
||||
./jenkins/install_and_run.sh --coverage \
|
||||
--python $TOX_BIN/python \
|
||||
./sytest/jenkins/install_and_run.sh \
|
||||
--synapse-directory $WORKSPACE \
|
||||
--port-base $PORT_BASE
|
||||
|
||||
cd ..
|
||||
cp sytest/.coverage.* .
|
||||
|
||||
# Combine the coverage reports
|
||||
echo "Combining:" .coverage.*
|
||||
$TOX_BIN/python -m coverage combine
|
||||
# Output coverage to coverage.xml
|
||||
$TOX_BIN/coverage xml -o coverage.xml
|
||||
|
|
|
@ -4,54 +4,12 @@ set -eux
|
|||
|
||||
: ${WORKSPACE:="$(pwd)"}
|
||||
|
||||
export WORKSPACE
|
||||
export PYTHONDONTWRITEBYTECODE=yep
|
||||
export SYNAPSE_CACHE_FACTOR=1
|
||||
|
||||
# Output test results as junit xml
|
||||
export TRIAL_FLAGS="--reporter=subunit"
|
||||
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
|
||||
# Write coverage reports to a separate file for each process
|
||||
export COVERAGE_OPTS="-p"
|
||||
export DUMP_COVERAGE_COMMAND="coverage help"
|
||||
./jenkins/prepare_synapse.sh
|
||||
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
|
||||
|
||||
# Output flake8 violations to violations.flake8.log
|
||||
# Don't exit with non-0 status code on Jenkins,
|
||||
# so that the build steps continue and a later step can decided whether to
|
||||
# UNSTABLE or FAILURE this build.
|
||||
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
|
||||
|
||||
rm .coverage* || echo "No coverage files to remove"
|
||||
|
||||
tox --notest -e py27
|
||||
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||
$TOX_BIN/pip install lxml
|
||||
|
||||
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
|
||||
|
||||
if [[ ! -e .sytest-base ]]; then
|
||||
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
|
||||
else
|
||||
(cd .sytest-base; git fetch -p)
|
||||
fi
|
||||
|
||||
rm -rf sytest
|
||||
git clone .sytest-base sytest --shared
|
||||
cd sytest
|
||||
|
||||
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
|
||||
|
||||
: ${PORT_BASE:=8500}
|
||||
./jenkins/install_and_run.sh --coverage \
|
||||
--python $TOX_BIN/python \
|
||||
./sytest/jenkins/install_and_run.sh \
|
||||
--synapse-directory $WORKSPACE \
|
||||
--port-base $PORT_BASE
|
||||
|
||||
cd ..
|
||||
cp sytest/.coverage.* .
|
||||
|
||||
# Combine the coverage reports
|
||||
echo "Combining:" .coverage.*
|
||||
$TOX_BIN/python -m coverage combine
|
||||
# Output coverage to coverage.xml
|
||||
$TOX_BIN/coverage xml -o coverage.xml
|
||||
|
|
|
@ -22,4 +22,8 @@ export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished w
|
|||
|
||||
rm .coverage* || echo "No coverage files to remove"
|
||||
|
||||
tox --notest -e py27
|
||||
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||
|
||||
tox -e py27
|
||||
|
|
44
jenkins/clone.sh
Executable file
44
jenkins/clone.sh
Executable file
|
@ -0,0 +1,44 @@
|
|||
#! /bin/bash
|
||||
|
||||
# This clones a project from github into a named subdirectory
|
||||
# If the project has a branch with the same name as this branch
|
||||
# then it will checkout that branch after cloning.
|
||||
# Otherwise it will checkout "origin/develop."
|
||||
# The first argument is the name of the directory to checkout
|
||||
# the branch into.
|
||||
# The second argument is the URL of the remote repository to checkout.
|
||||
# Usually something like https://github.com/matrix-org/sytest.git
|
||||
|
||||
set -eux
|
||||
|
||||
NAME=$1
|
||||
PROJECT=$2
|
||||
BASE=".$NAME-base"
|
||||
|
||||
# Update our mirror.
|
||||
if [ ! -d ".$NAME-base" ]; then
|
||||
# Create a local mirror of the source repository.
|
||||
# This saves us from having to download the entire repository
|
||||
# when this script is next run.
|
||||
git clone "$PROJECT" "$BASE" --mirror
|
||||
else
|
||||
# Fetch any updates from the source repository.
|
||||
(cd "$BASE"; git fetch -p)
|
||||
fi
|
||||
|
||||
# Remove the existing repository so that we have a clean copy
|
||||
rm -rf "$NAME"
|
||||
# Cloning with --shared means that we will share portions of the
|
||||
# .git directory with our local mirror.
|
||||
git clone "$BASE" "$NAME" --shared
|
||||
|
||||
# Jenkins may have supplied us with the name of the branch in the
|
||||
# environment. Otherwise we will have to guess based on the current
|
||||
# commit.
|
||||
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
|
||||
cd "$NAME"
|
||||
# check out the relevant branch
|
||||
git checkout "${GIT_BRANCH}" || (
|
||||
echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop"
|
||||
git checkout "origin/develop"
|
||||
)
|
19
jenkins/prepare_synapse.sh
Executable file
19
jenkins/prepare_synapse.sh
Executable file
|
@ -0,0 +1,19 @@
|
|||
#! /bin/bash
|
||||
|
||||
cd "`dirname $0`/.."
|
||||
|
||||
TOX_DIR=$WORKSPACE/.tox
|
||||
|
||||
mkdir -p $TOX_DIR
|
||||
|
||||
if ! [ $TOX_DIR -ef .tox ]; then
|
||||
ln -s "$TOX_DIR" .tox
|
||||
fi
|
||||
|
||||
# set up the virtualenv
|
||||
tox -e py27 --notest -v
|
||||
|
||||
TOX_BIN=$TOX_DIR/py27/bin
|
||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||
$TOX_BIN/pip install lxml
|
||||
$TOX_BIN/pip install psycopg2
|
|
@ -145,6 +145,11 @@ pre, code {
|
|||
text-decoration: none;
|
||||
}
|
||||
|
||||
.debug {
|
||||
font-size: 10px;
|
||||
color: #888;
|
||||
}
|
||||
|
||||
.footer {
|
||||
margin-top: 20px;
|
||||
text-align: center;
|
||||
|
|
|
@ -17,11 +17,15 @@
|
|||
</td>
|
||||
<td class="message_contents">
|
||||
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
|
||||
<div class="sender_name">{{ message.sender_name }}</div>
|
||||
<div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
|
||||
{% endif %}
|
||||
<div class="message_body">
|
||||
{% if message.msgtype == "m.text" %}
|
||||
{{ message.body_text_html }}
|
||||
{% elif message.msgtype == "m.emote" %}
|
||||
{{ message.body_text_html }}
|
||||
{% elif message.msgtype == "m.notice" %}
|
||||
{{ message.body_text_html }}
|
||||
{% elif message.msgtype == "m.image" %}
|
||||
<img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
|
||||
{% elif message.msgtype == "m.file" %}
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
{% for message in notif.messages %}
|
||||
{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
|
||||
{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
|
||||
{% if message.msgtype == "m.text" %}
|
||||
{{ message.body_text_plain }}
|
||||
{% elif message.msgtype == "m.emote" %}
|
||||
{{ message.body_text_plain }}
|
||||
{% elif message.msgtype == "m.notice" %}
|
||||
{{ message.body_text_plain }}
|
||||
{% elif message.msgtype == "m.image" %}
|
||||
{{ message.body_text_plain }}
|
||||
{% elif message.msgtype == "m.file" %}
|
||||
|
|
|
@ -30,18 +30,20 @@
|
|||
{% include 'room.html' with context %}
|
||||
{% endfor %}
|
||||
<div class="footer">
|
||||
<small>
|
||||
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room '{{ reason.room_name }}' because:<br/>
|
||||
1. An event was received at {{ reason.received_at|format_ts("%c") }}
|
||||
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} (delay_before_mail_ms) mins ago.<br/>
|
||||
<a href="{{ unsubscribe_link }}">Unsubscribe</a>
|
||||
<br/>
|
||||
<br/>
|
||||
<div class="debug">
|
||||
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
|
||||
an event was received at {{ reason.received_at|format_ts("%c") }}
|
||||
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
|
||||
{% if reason.last_sent_ts %}
|
||||
2. The last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
|
||||
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
|
||||
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
|
||||
{% else %}
|
||||
2. We can't remember the last time we sent a mail for this room.
|
||||
and we don't have a last time we sent a mail for this room.
|
||||
{% endif %}
|
||||
</small>
|
||||
<a href="{{ unsubscribe_link }}">Unsubscribe</a>
|
||||
</div>
|
||||
</div>
|
||||
</td>
|
||||
<td> </td>
|
||||
|
|
|
@ -116,17 +116,19 @@ def get_json(origin_name, origin_key, destination, path):
|
|||
authorization_headers = []
|
||||
|
||||
for key, sig in signed_json["signatures"][origin_name].items():
|
||||
authorization_headers.append(bytes(
|
||||
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
|
||||
header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
|
||||
origin_name, key, sig,
|
||||
)
|
||||
))
|
||||
authorization_headers.append(bytes(header))
|
||||
sys.stderr.write(header)
|
||||
sys.stderr.write("\n")
|
||||
|
||||
result = requests.get(
|
||||
lookup(destination, path),
|
||||
headers={"Authorization": authorization_headers[0]},
|
||||
verify=False,
|
||||
)
|
||||
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
|
||||
return result.json()
|
||||
|
||||
|
||||
|
@ -141,6 +143,7 @@ def main():
|
|||
)
|
||||
|
||||
json.dump(result, sys.stdout)
|
||||
print ""
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -1,10 +1,16 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
|
||||
import sys
|
||||
|
||||
import bcrypt
|
||||
import getpass
|
||||
|
||||
import yaml
|
||||
|
||||
bcrypt_rounds=12
|
||||
password_pepper = ""
|
||||
|
||||
def prompt_for_pass():
|
||||
password = getpass.getpass("Password: ")
|
||||
|
@ -28,12 +34,22 @@ if __name__ == "__main__":
|
|||
default=None,
|
||||
help="New password for user. Will prompt if omitted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c", "--config",
|
||||
type=argparse.FileType('r'),
|
||||
help="Path to server config file. Used to read in bcrypt_rounds and password_pepper.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if "config" in args and args.config:
|
||||
config = yaml.safe_load(args.config)
|
||||
bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds)
|
||||
password_config = config.get("password_config", {})
|
||||
password_pepper = password_config.get("pepper", password_pepper)
|
||||
password = args.password
|
||||
|
||||
if not password:
|
||||
password = prompt_for_pass()
|
||||
|
||||
print bcrypt.hashpw(password, bcrypt.gensalt(bcrypt_rounds))
|
||||
print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds))
|
||||
|
||||
|
|
|
@ -25,18 +25,26 @@ import urllib2
|
|||
import yaml
|
||||
|
||||
|
||||
def request_registration(user, password, server_location, shared_secret):
|
||||
def request_registration(user, password, server_location, shared_secret, admin=False):
|
||||
mac = hmac.new(
|
||||
key=shared_secret,
|
||||
msg=user,
|
||||
digestmod=hashlib.sha1,
|
||||
).hexdigest()
|
||||
)
|
||||
|
||||
mac.update(user)
|
||||
mac.update("\x00")
|
||||
mac.update(password)
|
||||
mac.update("\x00")
|
||||
mac.update("admin" if admin else "notadmin")
|
||||
|
||||
mac = mac.hexdigest()
|
||||
|
||||
data = {
|
||||
"user": user,
|
||||
"password": password,
|
||||
"mac": mac,
|
||||
"type": "org.matrix.login.shared_secret",
|
||||
"admin": admin,
|
||||
}
|
||||
|
||||
server_location = server_location.rstrip("/")
|
||||
|
@ -68,7 +76,7 @@ def request_registration(user, password, server_location, shared_secret):
|
|||
sys.exit(1)
|
||||
|
||||
|
||||
def register_new_user(user, password, server_location, shared_secret):
|
||||
def register_new_user(user, password, server_location, shared_secret, admin):
|
||||
if not user:
|
||||
try:
|
||||
default_user = getpass.getuser()
|
||||
|
@ -99,7 +107,14 @@ def register_new_user(user, password, server_location, shared_secret):
|
|||
print "Passwords do not match"
|
||||
sys.exit(1)
|
||||
|
||||
request_registration(user, password, server_location, shared_secret)
|
||||
if not admin:
|
||||
admin = raw_input("Make admin [no]: ")
|
||||
if admin in ("y", "yes", "true"):
|
||||
admin = True
|
||||
else:
|
||||
admin = False
|
||||
|
||||
request_registration(user, password, server_location, shared_secret, bool(admin))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -119,6 +134,11 @@ if __name__ == "__main__":
|
|||
default=None,
|
||||
help="New password for user. Will prompt if omitted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a", "--admin",
|
||||
action="store_true",
|
||||
help="Register new user as an admin. Will prompt if omitted.",
|
||||
)
|
||||
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument(
|
||||
|
@ -151,4 +171,4 @@ if __name__ == "__main__":
|
|||
else:
|
||||
secret = args.shared_secret
|
||||
|
||||
register_new_user(args.user, args.password, args.server_url, secret)
|
||||
register_new_user(args.user, args.password, args.server_url, secret, args.admin)
|
||||
|
|
|
@ -34,7 +34,7 @@ logger = logging.getLogger("synapse_port_db")
|
|||
|
||||
|
||||
BOOLEAN_COLUMNS = {
|
||||
"events": ["processed", "outlier"],
|
||||
"events": ["processed", "outlier", "contains_url"],
|
||||
"rooms": ["is_public"],
|
||||
"event_edges": ["is_state"],
|
||||
"presence_list": ["accepted"],
|
||||
|
@ -92,8 +92,12 @@ class Store(object):
|
|||
|
||||
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
|
||||
_simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
|
||||
_simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
|
||||
_simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
|
||||
_simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
|
||||
_simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]
|
||||
_simple_select_one_onecol_txn = SQLBaseStore.__dict__[
|
||||
"_simple_select_one_onecol_txn"
|
||||
]
|
||||
|
||||
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
|
||||
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
|
||||
|
@ -158,31 +162,40 @@ class Porter(object):
|
|||
def setup_table(self, table):
|
||||
if table in APPEND_ONLY_TABLES:
|
||||
# It's safe to just carry on inserting.
|
||||
next_chunk = yield self.postgres_store._simple_select_one_onecol(
|
||||
row = yield self.postgres_store._simple_select_one(
|
||||
table="port_from_sqlite3",
|
||||
keyvalues={"table_name": table},
|
||||
retcol="rowid",
|
||||
retcols=("forward_rowid", "backward_rowid"),
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
total_to_port = None
|
||||
if next_chunk is None:
|
||||
if row is None:
|
||||
if table == "sent_transactions":
|
||||
next_chunk, already_ported, total_to_port = (
|
||||
forward_chunk, already_ported, total_to_port = (
|
||||
yield self._setup_sent_transactions()
|
||||
)
|
||||
backward_chunk = 0
|
||||
else:
|
||||
yield self.postgres_store._simple_insert(
|
||||
table="port_from_sqlite3",
|
||||
values={"table_name": table, "rowid": 1}
|
||||
values={
|
||||
"table_name": table,
|
||||
"forward_rowid": 1,
|
||||
"backward_rowid": 0,
|
||||
}
|
||||
)
|
||||
|
||||
next_chunk = 1
|
||||
forward_chunk = 1
|
||||
backward_chunk = 0
|
||||
already_ported = 0
|
||||
else:
|
||||
forward_chunk = row["forward_rowid"]
|
||||
backward_chunk = row["backward_rowid"]
|
||||
|
||||
if total_to_port is None:
|
||||
already_ported, total_to_port = yield self._get_total_count_to_port(
|
||||
table, next_chunk
|
||||
table, forward_chunk, backward_chunk
|
||||
)
|
||||
else:
|
||||
def delete_all(txn):
|
||||
|
@ -196,46 +209,85 @@ class Porter(object):
|
|||
|
||||
yield self.postgres_store._simple_insert(
|
||||
table="port_from_sqlite3",
|
||||
values={"table_name": table, "rowid": 0}
|
||||
values={
|
||||
"table_name": table,
|
||||
"forward_rowid": 1,
|
||||
"backward_rowid": 0,
|
||||
}
|
||||
)
|
||||
|
||||
next_chunk = 1
|
||||
forward_chunk = 1
|
||||
backward_chunk = 0
|
||||
|
||||
already_ported, total_to_port = yield self._get_total_count_to_port(
|
||||
table, next_chunk
|
||||
table, forward_chunk, backward_chunk
|
||||
)
|
||||
|
||||
defer.returnValue((table, already_ported, total_to_port, next_chunk))
|
||||
defer.returnValue(
|
||||
(table, already_ported, total_to_port, forward_chunk, backward_chunk)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_table(self, table, postgres_size, table_size, next_chunk):
|
||||
def handle_table(self, table, postgres_size, table_size, forward_chunk,
|
||||
backward_chunk):
|
||||
if not table_size:
|
||||
return
|
||||
|
||||
self.progress.add_table(table, postgres_size, table_size)
|
||||
|
||||
if table == "event_search":
|
||||
yield self.handle_search_table(postgres_size, table_size, next_chunk)
|
||||
yield self.handle_search_table(
|
||||
postgres_size, table_size, forward_chunk, backward_chunk
|
||||
)
|
||||
return
|
||||
|
||||
select = (
|
||||
forward_select = (
|
||||
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
|
||||
% (table,)
|
||||
)
|
||||
|
||||
backward_select = (
|
||||
"SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
|
||||
% (table,)
|
||||
)
|
||||
|
||||
do_forward = [True]
|
||||
do_backward = [True]
|
||||
|
||||
while True:
|
||||
def r(txn):
|
||||
txn.execute(select, (next_chunk, self.batch_size,))
|
||||
rows = txn.fetchall()
|
||||
forward_rows = []
|
||||
backward_rows = []
|
||||
if do_forward[0]:
|
||||
txn.execute(forward_select, (forward_chunk, self.batch_size,))
|
||||
forward_rows = txn.fetchall()
|
||||
if not forward_rows:
|
||||
do_forward[0] = False
|
||||
|
||||
if do_backward[0]:
|
||||
txn.execute(backward_select, (backward_chunk, self.batch_size,))
|
||||
backward_rows = txn.fetchall()
|
||||
if not backward_rows:
|
||||
do_backward[0] = False
|
||||
|
||||
if forward_rows or backward_rows:
|
||||
headers = [column[0] for column in txn.description]
|
||||
else:
|
||||
headers = None
|
||||
|
||||
return headers, rows
|
||||
return headers, forward_rows, backward_rows
|
||||
|
||||
headers, rows = yield self.sqlite_store.runInteraction("select", r)
|
||||
headers, frows, brows = yield self.sqlite_store.runInteraction(
|
||||
"select", r
|
||||
)
|
||||
|
||||
if rows:
|
||||
next_chunk = rows[-1][0] + 1
|
||||
if frows or brows:
|
||||
if frows:
|
||||
forward_chunk = max(row[0] for row in frows) + 1
|
||||
if brows:
|
||||
backward_chunk = min(row[0] for row in brows) - 1
|
||||
|
||||
rows = frows + brows
|
||||
self._convert_rows(table, headers, rows)
|
||||
|
||||
def insert(txn):
|
||||
|
@ -247,7 +299,10 @@ class Porter(object):
|
|||
txn,
|
||||
table="port_from_sqlite3",
|
||||
keyvalues={"table_name": table},
|
||||
updatevalues={"rowid": next_chunk},
|
||||
updatevalues={
|
||||
"forward_rowid": forward_chunk,
|
||||
"backward_rowid": backward_chunk,
|
||||
},
|
||||
)
|
||||
|
||||
yield self.postgres_store.execute(insert)
|
||||
|
@ -259,7 +314,8 @@ class Porter(object):
|
|||
return
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_search_table(self, postgres_size, table_size, next_chunk):
|
||||
def handle_search_table(self, postgres_size, table_size, forward_chunk,
|
||||
backward_chunk):
|
||||
select = (
|
||||
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
|
||||
" FROM event_search as es"
|
||||
|
@ -270,7 +326,7 @@ class Porter(object):
|
|||
|
||||
while True:
|
||||
def r(txn):
|
||||
txn.execute(select, (next_chunk, self.batch_size,))
|
||||
txn.execute(select, (forward_chunk, self.batch_size,))
|
||||
rows = txn.fetchall()
|
||||
headers = [column[0] for column in txn.description]
|
||||
|
||||
|
@ -279,7 +335,7 @@ class Porter(object):
|
|||
headers, rows = yield self.sqlite_store.runInteraction("select", r)
|
||||
|
||||
if rows:
|
||||
next_chunk = rows[-1][0] + 1
|
||||
forward_chunk = rows[-1][0] + 1
|
||||
|
||||
# We have to treat event_search differently since it has a
|
||||
# different structure in the two different databases.
|
||||
|
@ -312,7 +368,10 @@ class Porter(object):
|
|||
txn,
|
||||
table="port_from_sqlite3",
|
||||
keyvalues={"table_name": "event_search"},
|
||||
updatevalues={"rowid": next_chunk},
|
||||
updatevalues={
|
||||
"forward_rowid": forward_chunk,
|
||||
"backward_rowid": backward_chunk,
|
||||
},
|
||||
)
|
||||
|
||||
yield self.postgres_store.execute(insert)
|
||||
|
@ -324,7 +383,6 @@ class Porter(object):
|
|||
else:
|
||||
return
|
||||
|
||||
|
||||
def setup_db(self, db_config, database_engine):
|
||||
db_conn = database_engine.module.connect(
|
||||
**{
|
||||
|
@ -395,10 +453,32 @@ class Porter(object):
|
|||
txn.execute(
|
||||
"CREATE TABLE port_from_sqlite3 ("
|
||||
" table_name varchar(100) NOT NULL UNIQUE,"
|
||||
" rowid bigint NOT NULL"
|
||||
" forward_rowid bigint NOT NULL,"
|
||||
" backward_rowid bigint NOT NULL"
|
||||
")"
|
||||
)
|
||||
|
||||
# The old port script created a table with just a "rowid" column.
|
||||
# We want people to be able to rerun this script from an old port
|
||||
# so that they can pick up any missing events that were not
|
||||
# ported across.
|
||||
def alter_table(txn):
|
||||
txn.execute(
|
||||
"ALTER TABLE IF EXISTS port_from_sqlite3"
|
||||
" RENAME rowid TO forward_rowid"
|
||||
)
|
||||
txn.execute(
|
||||
"ALTER TABLE IF EXISTS port_from_sqlite3"
|
||||
" ADD backward_rowid bigint NOT NULL DEFAULT 0"
|
||||
)
|
||||
|
||||
try:
|
||||
yield self.postgres_store.runInteraction(
|
||||
"alter_table", alter_table
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info("Failed to create port table: %s", e)
|
||||
|
||||
try:
|
||||
yield self.postgres_store.runInteraction(
|
||||
"create_port_table", create_port_table
|
||||
|
@ -458,7 +538,7 @@ class Porter(object):
|
|||
@defer.inlineCallbacks
|
||||
def _setup_sent_transactions(self):
|
||||
# Only save things from the last day
|
||||
yesterday = int(time.time()*1000) - 86400000
|
||||
yesterday = int(time.time() * 1000) - 86400000
|
||||
|
||||
# And save the max transaction id from each destination
|
||||
select = (
|
||||
|
@ -514,7 +594,11 @@ class Porter(object):
|
|||
|
||||
yield self.postgres_store._simple_insert(
|
||||
table="port_from_sqlite3",
|
||||
values={"table_name": "sent_transactions", "rowid": next_chunk}
|
||||
values={
|
||||
"table_name": "sent_transactions",
|
||||
"forward_rowid": next_chunk,
|
||||
"backward_rowid": 0,
|
||||
}
|
||||
)
|
||||
|
||||
def get_sent_table_size(txn):
|
||||
|
@ -535,13 +619,18 @@ class Porter(object):
|
|||
defer.returnValue((next_chunk, inserted_rows, total_count))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_remaining_count_to_port(self, table, next_chunk):
|
||||
rows = yield self.sqlite_store.execute_sql(
|
||||
def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
|
||||
frows = yield self.sqlite_store.execute_sql(
|
||||
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
|
||||
next_chunk,
|
||||
forward_chunk,
|
||||
)
|
||||
|
||||
defer.returnValue(rows[0][0])
|
||||
brows = yield self.sqlite_store.execute_sql(
|
||||
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
|
||||
backward_chunk,
|
||||
)
|
||||
|
||||
defer.returnValue(frows[0][0] + brows[0][0])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_already_ported_count(self, table):
|
||||
|
@ -552,10 +641,10 @@ class Porter(object):
|
|||
defer.returnValue(rows[0][0])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_total_count_to_port(self, table, next_chunk):
|
||||
def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
|
||||
remaining, done = yield defer.gatherResults(
|
||||
[
|
||||
self._get_remaining_count_to_port(table, next_chunk),
|
||||
self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
|
||||
self._get_already_ported_count(table),
|
||||
],
|
||||
consumeErrors=True,
|
||||
|
@ -686,7 +775,7 @@ class CursesProgress(Progress):
|
|||
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
|
||||
|
||||
self.stdscr.addstr(
|
||||
i+2, left_margin + max_len - len(table),
|
||||
i + 2, left_margin + max_len - len(table),
|
||||
table,
|
||||
curses.A_BOLD | color,
|
||||
)
|
||||
|
@ -694,18 +783,18 @@ class CursesProgress(Progress):
|
|||
size = 20
|
||||
|
||||
progress = "[%s%s]" % (
|
||||
"#" * int(perc*size/100),
|
||||
" " * (size - int(perc*size/100)),
|
||||
"#" * int(perc * size / 100),
|
||||
" " * (size - int(perc * size / 100)),
|
||||
)
|
||||
|
||||
self.stdscr.addstr(
|
||||
i+2, left_margin + max_len + middle_space,
|
||||
i + 2, left_margin + max_len + middle_space,
|
||||
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
|
||||
)
|
||||
|
||||
if self.finished:
|
||||
self.stdscr.addstr(
|
||||
rows-1, 0,
|
||||
rows - 1, 0,
|
||||
"Press any key to exit...",
|
||||
)
|
||||
|
||||
|
|
|
@ -16,7 +16,5 @@ ignore =
|
|||
|
||||
[flake8]
|
||||
max-line-length = 90
|
||||
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
|
||||
|
||||
[pep8]
|
||||
max-line-length = 90
|
||||
# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
|
||||
ignore = W503
|
||||
|
|
|
@ -16,4 +16,4 @@
|
|||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.14.0"
|
||||
__version__ = "0.17.0"
|
||||
|
|
|
@ -13,23 +13,22 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This module contains classes for authenticating the user."""
|
||||
import logging
|
||||
|
||||
import pymacaroons
|
||||
from canonicaljson import encode_canonical_json
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.sign import verify_signed_json, SignatureVerifyException
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
||||
from synapse.types import Requester, UserID, get_domain_from_id
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from synapse.util.metrics import Measure
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
import logging
|
||||
import pymacaroons
|
||||
import synapse.types
|
||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -42,13 +41,20 @@ AuthEventTypes = (
|
|||
|
||||
|
||||
class Auth(object):
|
||||
|
||||
"""
|
||||
FIXME: This class contains a mix of functions for authenticating users
|
||||
of our client-server API and authenticating events added to room graphs.
|
||||
"""
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.state = hs.get_state_handler()
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
|
||||
# Docs for these currently lives at
|
||||
# https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
|
||||
# In addition, we have type == delete_pusher which grants access only to
|
||||
# delete pushers.
|
||||
self._KNOWN_CAVEAT_PREFIXES = set([
|
||||
"gen = ",
|
||||
"guest = ",
|
||||
|
@ -57,7 +63,7 @@ class Auth(object):
|
|||
"user_id = ",
|
||||
])
|
||||
|
||||
def check(self, event, auth_events):
|
||||
def check(self, event, auth_events, do_sig_check=True):
|
||||
""" Checks if this event is correctly authed.
|
||||
|
||||
Args:
|
||||
|
@ -73,6 +79,13 @@ class Auth(object):
|
|||
|
||||
if not hasattr(event, "room_id"):
|
||||
raise AuthError(500, "Event has no room_id: %s" % event)
|
||||
|
||||
sender_domain = get_domain_from_id(event.sender)
|
||||
|
||||
# Check the sender's domain has signed the event
|
||||
if do_sig_check and not event.signatures.get(sender_domain):
|
||||
raise AuthError(403, "Event not signed by sending server")
|
||||
|
||||
if auth_events is None:
|
||||
# Oh, we don't know what the state of the room was, so we
|
||||
# are trusting that this is allowed (at least for now)
|
||||
|
@ -80,6 +93,12 @@ class Auth(object):
|
|||
return True
|
||||
|
||||
if event.type == EventTypes.Create:
|
||||
room_id_domain = get_domain_from_id(event.room_id)
|
||||
if room_id_domain != sender_domain:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Creation event's room_id domain does not match sender's"
|
||||
)
|
||||
# FIXME
|
||||
return True
|
||||
|
||||
|
@ -102,6 +121,22 @@ class Auth(object):
|
|||
|
||||
# FIXME: Temp hack
|
||||
if event.type == EventTypes.Aliases:
|
||||
if not event.is_state():
|
||||
raise AuthError(
|
||||
403,
|
||||
"Alias event must be a state event",
|
||||
)
|
||||
if not event.state_key:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Alias event must have non-empty state_key"
|
||||
)
|
||||
sender_domain = get_domain_from_id(event.sender)
|
||||
if event.state_key != sender_domain:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Alias event's state_key does not match sender's domain"
|
||||
)
|
||||
return True
|
||||
|
||||
logger.debug(
|
||||
|
@ -120,6 +155,24 @@ class Auth(object):
|
|||
return allowed
|
||||
|
||||
self.check_event_sender_in_room(event, auth_events)
|
||||
|
||||
# Special case to allow m.room.third_party_invite events wherever
|
||||
# a user is allowed to issue invites. Fixes
|
||||
# https://github.com/vector-im/vector-web/issues/1208 hopefully
|
||||
if event.type == EventTypes.ThirdPartyInvite:
|
||||
user_level = self._get_user_power_level(event.user_id, auth_events)
|
||||
invite_level = self._get_named_level(auth_events, "invite", 0)
|
||||
|
||||
if user_level < invite_level:
|
||||
raise AuthError(
|
||||
403, (
|
||||
"You cannot issue a third party invite for %s." %
|
||||
(event.content.display_name,)
|
||||
)
|
||||
)
|
||||
else:
|
||||
return True
|
||||
|
||||
self._can_send_event(event, auth_events)
|
||||
|
||||
if event.type == EventTypes.PowerLevels:
|
||||
|
@ -323,6 +376,10 @@ class Auth(object):
|
|||
if Membership.INVITE == membership and "third_party_invite" in event.content:
|
||||
if not self._verify_third_party_invite(event, auth_events):
|
||||
raise AuthError(403, "You are not invited to this room.")
|
||||
if target_banned:
|
||||
raise AuthError(
|
||||
403, "%s is banned from the room" % (target_user_id,)
|
||||
)
|
||||
return True
|
||||
|
||||
if Membership.JOIN != membership:
|
||||
|
@ -507,15 +564,13 @@ class Auth(object):
|
|||
return default
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_by_req(self, request, allow_guest=False):
|
||||
def get_user_by_req(self, request, allow_guest=False, rights="access"):
|
||||
""" Get a registered user's ID.
|
||||
|
||||
Args:
|
||||
request - An HTTP request with an access_token query parameter.
|
||||
Returns:
|
||||
tuple of:
|
||||
UserID (str)
|
||||
Access token ID (str)
|
||||
defer.Deferred: resolves to a ``synapse.types.Requester`` object
|
||||
Raises:
|
||||
AuthError if no user by that token exists or the token is invalid.
|
||||
"""
|
||||
|
@ -524,16 +579,18 @@ class Auth(object):
|
|||
user_id = yield self._get_appservice_user_id(request.args)
|
||||
if user_id:
|
||||
request.authenticated_entity = user_id
|
||||
defer.returnValue(
|
||||
Requester(UserID.from_string(user_id), "", False)
|
||||
)
|
||||
defer.returnValue(synapse.types.create_requester(user_id))
|
||||
|
||||
access_token = request.args["access_token"][0]
|
||||
user_info = yield self.get_user_by_access_token(access_token)
|
||||
user_info = yield self.get_user_by_access_token(access_token, rights)
|
||||
user = user_info["user"]
|
||||
token_id = user_info["token_id"]
|
||||
is_guest = user_info["is_guest"]
|
||||
|
||||
# device_id may not be present if get_user_by_access_token has been
|
||||
# stubbed out.
|
||||
device_id = user_info.get("device_id")
|
||||
|
||||
ip_addr = self.hs.get_ip_from_request(request)
|
||||
user_agent = request.requestHeaders.getRawHeaders(
|
||||
"User-Agent",
|
||||
|
@ -545,7 +602,8 @@ class Auth(object):
|
|||
user=user,
|
||||
access_token=access_token,
|
||||
ip=ip_addr,
|
||||
user_agent=user_agent
|
||||
user_agent=user_agent,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
if is_guest and not allow_guest:
|
||||
|
@ -555,7 +613,8 @@ class Auth(object):
|
|||
|
||||
request.authenticated_entity = user.to_string()
|
||||
|
||||
defer.returnValue(Requester(user, token_id, is_guest))
|
||||
defer.returnValue(synapse.types.create_requester(
|
||||
user, token_id, is_guest, device_id))
|
||||
except KeyError:
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
||||
|
@ -590,7 +649,7 @@ class Auth(object):
|
|||
defer.returnValue(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_by_access_token(self, token):
|
||||
def get_user_by_access_token(self, token, rights="access"):
|
||||
""" Get a registered user's ID.
|
||||
|
||||
Args:
|
||||
|
@ -601,47 +660,61 @@ class Auth(object):
|
|||
AuthError if no user by that token exists or the token is invalid.
|
||||
"""
|
||||
try:
|
||||
ret = yield self.get_user_from_macaroon(token)
|
||||
ret = yield self.get_user_from_macaroon(token, rights)
|
||||
except AuthError:
|
||||
# TODO(daniel): Remove this fallback when all existing access tokens
|
||||
# have been re-issued as macaroons.
|
||||
if self.hs.config.expire_access_token:
|
||||
raise
|
||||
ret = yield self._look_up_user_by_access_token(token)
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_from_macaroon(self, macaroon_str):
|
||||
def get_user_from_macaroon(self, macaroon_str, rights="access"):
|
||||
try:
|
||||
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
||||
|
||||
self.validate_macaroon(macaroon, "access", self.hs.config.expire_access_token)
|
||||
user_id = self.get_user_id_from_macaroon(macaroon)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
self.validate_macaroon(
|
||||
macaroon, rights, self.hs.config.expire_access_token,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
user_prefix = "user_id = "
|
||||
user = None
|
||||
guest = False
|
||||
for caveat in macaroon.caveats:
|
||||
if caveat.caveat_id.startswith(user_prefix):
|
||||
user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
|
||||
elif caveat.caveat_id == "guest = true":
|
||||
if caveat.caveat_id == "guest = true":
|
||||
guest = True
|
||||
|
||||
if user is None:
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
|
||||
if guest:
|
||||
ret = {
|
||||
"user": user,
|
||||
"is_guest": True,
|
||||
"token_id": None,
|
||||
"device_id": None,
|
||||
}
|
||||
elif rights == "delete_pusher":
|
||||
# We don't store these tokens in the database
|
||||
ret = {
|
||||
"user": user,
|
||||
"is_guest": False,
|
||||
"token_id": None,
|
||||
"device_id": None,
|
||||
}
|
||||
else:
|
||||
# This codepath exists so that we can actually return a
|
||||
# token ID, because we use token IDs in place of device
|
||||
# identifiers throughout the codebase.
|
||||
# TODO(daniel): Remove this fallback when device IDs are
|
||||
# properly implemented.
|
||||
# This codepath exists for several reasons:
|
||||
# * so that we can actually return a token ID, which is used
|
||||
# in some parts of the schema (where we probably ought to
|
||||
# use device IDs instead)
|
||||
# * the only way we currently have to invalidate an
|
||||
# access_token is by removing it from the database, so we
|
||||
# have to check here that it is still in the db
|
||||
# * some attributes (notably device_id) aren't stored in the
|
||||
# macaroon. They probably should be.
|
||||
# TODO: build the dictionary from the macaroon once the
|
||||
# above are fixed
|
||||
ret = yield self._look_up_user_by_access_token(macaroon_str)
|
||||
if ret["user"] != user:
|
||||
logger.error(
|
||||
|
@ -661,21 +734,46 @@ class Auth(object):
|
|||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
|
||||
def validate_macaroon(self, macaroon, type_string, verify_expiry):
|
||||
def get_user_id_from_macaroon(self, macaroon):
|
||||
"""Retrieve the user_id given by the caveats on the macaroon.
|
||||
|
||||
Does *not* validate the macaroon.
|
||||
|
||||
Args:
|
||||
macaroon (pymacaroons.Macaroon): The macaroon to validate
|
||||
|
||||
Returns:
|
||||
(str) user id
|
||||
|
||||
Raises:
|
||||
AuthError if there is no user_id caveat in the macaroon
|
||||
"""
|
||||
user_prefix = "user_id = "
|
||||
for caveat in macaroon.caveats:
|
||||
if caveat.caveat_id.startswith(user_prefix):
|
||||
return caveat.caveat_id[len(user_prefix):]
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
|
||||
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
|
||||
"""
|
||||
validate that a Macaroon is understood by and was signed by this server.
|
||||
|
||||
Args:
|
||||
macaroon(pymacaroons.Macaroon): The macaroon to validate
|
||||
type_string(str): The kind of token this is (e.g. "access", "refresh")
|
||||
type_string(str): The kind of token required (e.g. "access", "refresh",
|
||||
"delete_pusher")
|
||||
verify_expiry(bool): Whether to verify whether the macaroon has expired.
|
||||
This should really always be True, but no clients currently implement
|
||||
token refresh, so we can't enforce expiry yet.
|
||||
user_id (str): The user_id required
|
||||
"""
|
||||
v = pymacaroons.Verifier()
|
||||
v.satisfy_exact("gen = 1")
|
||||
v.satisfy_exact("type = " + type_string)
|
||||
v.satisfy_general(lambda c: c.startswith("user_id = "))
|
||||
v.satisfy_exact("user_id = %s" % user_id)
|
||||
v.satisfy_exact("guest = true")
|
||||
if verify_expiry:
|
||||
v.satisfy_general(self._verify_expiry)
|
||||
|
@ -714,10 +812,14 @@ class Auth(object):
|
|||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
# we use ret.get() below because *lots* of unit tests stub out
|
||||
# get_user_by_access_token in a way where it only returns a couple of
|
||||
# the fields.
|
||||
user_info = {
|
||||
"user": UserID.from_string(ret.get("name")),
|
||||
"token_id": ret.get("token_id", None),
|
||||
"is_guest": False,
|
||||
"device_id": ret.get("device_id"),
|
||||
}
|
||||
defer.returnValue(user_info)
|
||||
|
||||
|
|
|
@ -42,8 +42,10 @@ class Codes(object):
|
|||
TOO_LARGE = "M_TOO_LARGE"
|
||||
EXCLUSIVE = "M_EXCLUSIVE"
|
||||
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
||||
THREEPID_IN_USE = "THREEPID_IN_USE"
|
||||
THREEPID_IN_USE = "M_THREEPID_IN_USE"
|
||||
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
|
||||
INVALID_USERNAME = "M_INVALID_USERNAME"
|
||||
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
||||
|
||||
|
||||
class CodeMessageException(RuntimeError):
|
||||
|
|
|
@ -191,6 +191,17 @@ class Filter(object):
|
|||
def __init__(self, filter_json):
|
||||
self.filter_json = filter_json
|
||||
|
||||
self.types = self.filter_json.get("types", None)
|
||||
self.not_types = self.filter_json.get("not_types", [])
|
||||
|
||||
self.rooms = self.filter_json.get("rooms", None)
|
||||
self.not_rooms = self.filter_json.get("not_rooms", [])
|
||||
|
||||
self.senders = self.filter_json.get("senders", None)
|
||||
self.not_senders = self.filter_json.get("not_senders", [])
|
||||
|
||||
self.contains_url = self.filter_json.get("contains_url", None)
|
||||
|
||||
def check(self, event):
|
||||
"""Checks whether the filter matches the given event.
|
||||
|
||||
|
@ -209,9 +220,10 @@ class Filter(object):
|
|||
event.get("room_id", None),
|
||||
sender,
|
||||
event.get("type", None),
|
||||
"url" in event.get("content", {})
|
||||
)
|
||||
|
||||
def check_fields(self, room_id, sender, event_type):
|
||||
def check_fields(self, room_id, sender, event_type, contains_url):
|
||||
"""Checks whether the filter matches the given event fields.
|
||||
|
||||
Returns:
|
||||
|
@ -225,15 +237,20 @@ class Filter(object):
|
|||
|
||||
for name, match_func in literal_keys.items():
|
||||
not_name = "not_%s" % (name,)
|
||||
disallowed_values = self.filter_json.get(not_name, [])
|
||||
disallowed_values = getattr(self, not_name)
|
||||
if any(map(match_func, disallowed_values)):
|
||||
return False
|
||||
|
||||
allowed_values = self.filter_json.get(name, None)
|
||||
allowed_values = getattr(self, name)
|
||||
if allowed_values is not None:
|
||||
if not any(map(match_func, allowed_values)):
|
||||
return False
|
||||
|
||||
contains_url_filter = self.filter_json.get("contains_url")
|
||||
if contains_url_filter is not None:
|
||||
if contains_url_filter != contains_url:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def filter_rooms(self, room_ids):
|
||||
|
|
|
@ -16,13 +16,11 @@
|
|||
import sys
|
||||
sys.dont_write_bytecode = True
|
||||
|
||||
from synapse.python_dependencies import (
|
||||
check_requirements, MissingRequirementError
|
||||
) # NOQA
|
||||
from synapse import python_dependencies # noqa: E402
|
||||
|
||||
try:
|
||||
check_requirements()
|
||||
except MissingRequirementError as e:
|
||||
python_dependencies.check_requirements()
|
||||
except python_dependencies.MissingRequirementError as e:
|
||||
message = "\n".join([
|
||||
"Missing Requirement: %s" % (e.message,),
|
||||
"To install run:",
|
||||
|
|
206
synapse/app/federation_reader.py
Normal file
206
synapse/app/federation_reader.py
Normal file
|
@ -0,0 +1,206 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import synapse
|
||||
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.logger import setup_logging
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||
from synapse.replication.slave.storage.keys import SlavedKeyStore
|
||||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||
from synapse.replication.slave.storage.directory import DirectoryStore
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
from synapse.api.urls import FEDERATION_PREFIX
|
||||
from synapse.federation.transport.server import TransportLayerServer
|
||||
from synapse.crypto import context_factory
|
||||
|
||||
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from daemonize import Daemonize
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import gc
|
||||
|
||||
logger = logging.getLogger("synapse.app.federation_reader")
|
||||
|
||||
|
||||
class FederationReaderSlavedStore(
|
||||
SlavedEventStore,
|
||||
SlavedKeyStore,
|
||||
RoomStore,
|
||||
DirectoryStore,
|
||||
TransactionStore,
|
||||
BaseSlavedStore,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class FederationReaderServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
|
||||
logger.info("Finished setting up.")
|
||||
|
||||
def _listen_http(self, listener_config):
|
||||
port = listener_config["port"]
|
||||
bind_address = listener_config.get("bind_address", "")
|
||||
site_tag = listener_config.get("tag", port)
|
||||
resources = {}
|
||||
for res in listener_config["resources"]:
|
||||
for name in res["names"]:
|
||||
if name == "metrics":
|
||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||
elif name == "federation":
|
||||
resources.update({
|
||||
FEDERATION_PREFIX: TransportLayerServer(self),
|
||||
})
|
||||
|
||||
root_resource = create_resource_tree(resources, Resource())
|
||||
reactor.listenTCP(
|
||||
port,
|
||||
SynapseSite(
|
||||
"synapse.access.http.%s" % (site_tag,),
|
||||
site_tag,
|
||||
listener_config,
|
||||
root_resource,
|
||||
),
|
||||
interface=bind_address
|
||||
)
|
||||
logger.info("Synapse federation reader now listening on port %d", port)
|
||||
|
||||
def start_listening(self, listeners):
|
||||
for listener in listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listen_http(listener)
|
||||
elif listener["type"] == "manhole":
|
||||
reactor.listenTCP(
|
||||
listener["port"],
|
||||
manhole(
|
||||
username="matrix",
|
||||
password="rabbithole",
|
||||
globals={"hs": self},
|
||||
),
|
||||
interface=listener.get("bind_address", '127.0.0.1')
|
||||
)
|
||||
else:
|
||||
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def replicate(self):
|
||||
http_client = self.get_simple_http_client()
|
||||
store = self.get_datastore()
|
||||
replication_url = self.config.worker_replication_url
|
||||
|
||||
while True:
|
||||
try:
|
||||
args = store.stream_positions()
|
||||
args["timeout"] = 30000
|
||||
result = yield http_client.get_json(replication_url, args=args)
|
||||
yield store.process_replication(result)
|
||||
except:
|
||||
logger.exception("Error replicating from %r", replication_url)
|
||||
yield sleep(5)
|
||||
|
||||
|
||||
def start(config_options):
|
||||
try:
|
||||
config = HomeServerConfig.load_config(
|
||||
"Synapse federation reader", config_options
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
assert config.worker_app == "synapse.app.federation_reader"
|
||||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
|
||||
ss = FederationReaderServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
)
|
||||
|
||||
ss.setup()
|
||||
ss.get_handlers()
|
||||
ss.start_listening(config.worker_listeners)
|
||||
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
gc.set_threshold(*config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
def start():
|
||||
ss.get_datastore().start_profiling()
|
||||
ss.replicate()
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
if config.worker_daemonize:
|
||||
daemon = Daemonize(
|
||||
app="synapse-federation-reader",
|
||||
pid=config.worker_pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
daemon.start()
|
||||
else:
|
||||
run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
import synapse
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
@ -50,6 +51,7 @@ from synapse.api.urls import (
|
|||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.crypto import context_factory
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.metrics import register_memory_metrics
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
|
||||
from synapse.federation.transport.server import TransportLayerServer
|
||||
|
@ -146,7 +148,7 @@ class SynapseHomeServer(HomeServer):
|
|||
MEDIA_PREFIX: media_repo,
|
||||
LEGACY_MEDIA_PREFIX: media_repo,
|
||||
CONTENT_REPO_PREFIX: ContentRepoResource(
|
||||
self, self.config.uploads_path, self.auth, self.content_addr
|
||||
self, self.config.uploads_path
|
||||
),
|
||||
})
|
||||
|
||||
|
@ -265,10 +267,9 @@ def setup(config_options):
|
|||
HomeServer
|
||||
"""
|
||||
try:
|
||||
config = HomeServerConfig.load_config(
|
||||
config = HomeServerConfig.load_or_generate_config(
|
||||
"Synapse Homeserver",
|
||||
config_options,
|
||||
generate_section="Homeserver"
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
|
@ -284,7 +285,7 @@ def setup(config_options):
|
|||
# check any extra requirements we have now we have a config
|
||||
check_requirements(config)
|
||||
|
||||
version_string = get_version_string("Synapse", synapse)
|
||||
version_string = "Synapse/" + get_version_string(synapse)
|
||||
|
||||
logger.info("Server hostname: %s", config.server_name)
|
||||
logger.info("Server version: %s", version_string)
|
||||
|
@ -301,7 +302,6 @@ def setup(config_options):
|
|||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
config=config,
|
||||
content_addr=config.content_addr,
|
||||
version_string=version_string,
|
||||
database_engine=database_engine,
|
||||
)
|
||||
|
@ -336,6 +336,8 @@ def setup(config_options):
|
|||
hs.get_datastore().start_doing_background_updates()
|
||||
hs.get_replication_layer().start_get_pdu_cache()
|
||||
|
||||
register_memory_metrics(hs)
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
return hs
|
||||
|
@ -351,6 +353,8 @@ class SynapseService(service.Service):
|
|||
def startService(self):
|
||||
hs = setup(self.config)
|
||||
change_resource_limit(hs.config.soft_file_limit)
|
||||
if hs.config.gc_thresholds:
|
||||
gc.set_threshold(*hs.config.gc_thresholds)
|
||||
|
||||
def stopService(self):
|
||||
return self._port.stopListening()
|
||||
|
@ -422,6 +426,8 @@ def run(hs):
|
|||
# sys.settrace(logcontext_tracer)
|
||||
with LoggingContext("run"):
|
||||
change_resource_limit(hs.config.soft_file_limit)
|
||||
if hs.config.gc_thresholds:
|
||||
gc.set_threshold(*hs.config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
if hs.config.daemonize:
|
||||
|
|
|
@ -18,9 +18,8 @@ import synapse
|
|||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.database import DatabaseConfig
|
||||
from synapse.config.logger import LoggingConfig
|
||||
from synapse.config.emailconfig import EmailConfig
|
||||
from synapse.config.logger import setup_logging
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.storage.roommember import RoomMemberStore
|
||||
|
@ -44,61 +43,11 @@ from daemonize import Daemonize
|
|||
|
||||
import sys
|
||||
import logging
|
||||
import gc
|
||||
|
||||
logger = logging.getLogger("synapse.app.pusher")
|
||||
|
||||
|
||||
class SlaveConfig(DatabaseConfig):
|
||||
def read_config(self, config):
|
||||
self.replication_url = config["replication_url"]
|
||||
self.server_name = config["server_name"]
|
||||
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
|
||||
"use_insecure_ssl_client_just_for_testing_do_not_use", False
|
||||
)
|
||||
self.user_agent_suffix = None
|
||||
self.start_pushers = True
|
||||
self.listeners = config["listeners"]
|
||||
self.soft_file_limit = config.get("soft_file_limit")
|
||||
self.daemonize = config.get("daemonize")
|
||||
self.pid_file = self.abspath(config.get("pid_file"))
|
||||
self.public_baseurl = config["public_baseurl"]
|
||||
|
||||
def default_config(self, server_name, **kwargs):
|
||||
pid_file = self.abspath("pusher.pid")
|
||||
return """\
|
||||
# Slave configuration
|
||||
|
||||
# The replication listener on the synapse to talk to.
|
||||
#replication_url: https://localhost:{replication_port}/_synapse/replication
|
||||
|
||||
server_name: "%(server_name)s"
|
||||
|
||||
listeners: []
|
||||
# Enable a ssh manhole listener on the pusher.
|
||||
# - type: manhole
|
||||
# port: {manhole_port}
|
||||
# bind_address: 127.0.0.1
|
||||
# Enable a metric listener on the pusher.
|
||||
# - type: http
|
||||
# port: {metrics_port}
|
||||
# bind_address: 127.0.0.1
|
||||
# resources:
|
||||
# - names: ["metrics"]
|
||||
# compress: False
|
||||
|
||||
report_stats: False
|
||||
|
||||
daemonize: False
|
||||
|
||||
pid_file: %(pid_file)s
|
||||
|
||||
""" % locals()
|
||||
|
||||
|
||||
class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig):
|
||||
pass
|
||||
|
||||
|
||||
class PusherSlaveStore(
|
||||
SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore,
|
||||
SlavedAccountDataStore
|
||||
|
@ -163,7 +112,7 @@ class PusherServer(HomeServer):
|
|||
|
||||
def remove_pusher(self, app_id, push_key, user_id):
|
||||
http_client = self.get_simple_http_client()
|
||||
replication_url = self.config.replication_url
|
||||
replication_url = self.config.worker_replication_url
|
||||
url = replication_url + "/remove_pushers"
|
||||
return http_client.post_json_get_json(url, {
|
||||
"remove": [{
|
||||
|
@ -196,8 +145,8 @@ class PusherServer(HomeServer):
|
|||
)
|
||||
logger.info("Synapse pusher now listening on port %d", port)
|
||||
|
||||
def start_listening(self):
|
||||
for listener in self.config.listeners:
|
||||
def start_listening(self, listeners):
|
||||
for listener in listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listen_http(listener)
|
||||
elif listener["type"] == "manhole":
|
||||
|
@ -217,7 +166,7 @@ class PusherServer(HomeServer):
|
|||
def replicate(self):
|
||||
http_client = self.get_simple_http_client()
|
||||
store = self.get_datastore()
|
||||
replication_url = self.config.replication_url
|
||||
replication_url = self.config.worker_replication_url
|
||||
pusher_pool = self.get_pusherpool()
|
||||
clock = self.get_clock()
|
||||
|
||||
|
@ -290,22 +239,33 @@ class PusherServer(HomeServer):
|
|||
poke_pushers(result)
|
||||
except:
|
||||
logger.exception("Error replicating from %r", replication_url)
|
||||
sleep(30)
|
||||
yield sleep(30)
|
||||
|
||||
|
||||
def setup(config_options):
|
||||
def start(config_options):
|
||||
try:
|
||||
config = PusherSlaveConfig.load_config(
|
||||
config = HomeServerConfig.load_config(
|
||||
"Synapse pusher", config_options
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
if not config:
|
||||
sys.exit(0)
|
||||
assert config.worker_app == "synapse.app.pusher"
|
||||
|
||||
config.setup_logging()
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
if config.start_pushers:
|
||||
sys.stderr.write(
|
||||
"\nThe pushers must be disabled in the main synapse process"
|
||||
"\nbefore they can be run in a separate worker."
|
||||
"\nPlease add ``start_pushers: false`` to the main config"
|
||||
"\n"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Force the pushers to start since they will be disabled in the main config
|
||||
config.start_pushers = True
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
|
@ -313,14 +273,20 @@ def setup(config_options):
|
|||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
config=config,
|
||||
version_string=get_version_string("Synapse", synapse),
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
)
|
||||
|
||||
ps.setup()
|
||||
ps.start_listening()
|
||||
ps.start_listening(config.worker_listeners)
|
||||
|
||||
change_resource_limit(ps.config.soft_file_limit)
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
gc.set_threshold(*config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
def start():
|
||||
ps.replicate()
|
||||
|
@ -329,28 +295,20 @@ def setup(config_options):
|
|||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
return ps
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with LoggingContext("main"):
|
||||
ps = setup(sys.argv[1:])
|
||||
|
||||
if ps.config.daemonize:
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
change_resource_limit(ps.config.soft_file_limit)
|
||||
reactor.run()
|
||||
|
||||
if config.worker_daemonize:
|
||||
daemon = Daemonize(
|
||||
app="synapse-pusher",
|
||||
pid=ps.config.pid_file,
|
||||
pid=config.worker_pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
daemon.start()
|
||||
else:
|
||||
reactor.run()
|
||||
run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with LoggingContext("main"):
|
||||
ps = start(sys.argv[1:])
|
||||
|
|
465
synapse/app/synchrotron.py
Normal file
465
synapse/app/synchrotron.py
Normal file
|
@ -0,0 +1,465 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import synapse
|
||||
|
||||
from synapse.api.constants import EventTypes, PresenceState
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.logger import setup_logging
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.handlers.presence import PresenceHandler
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.rest.client.v2_alpha import sync
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
|
||||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
|
||||
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
|
||||
from synapse.replication.slave.storage.presence import SlavedPresenceStore
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.client_ips import ClientIpStore
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.storage.presence import PresenceStore, UserPresenceState
|
||||
from synapse.storage.roommember import RoomMemberStore
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext, preserve_fn
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from daemonize import Daemonize
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import contextlib
|
||||
import gc
|
||||
import ujson as json
|
||||
|
||||
logger = logging.getLogger("synapse.app.synchrotron")
|
||||
|
||||
|
||||
class SynchrotronSlavedStore(
|
||||
SlavedPushRuleStore,
|
||||
SlavedEventStore,
|
||||
SlavedReceiptsStore,
|
||||
SlavedAccountDataStore,
|
||||
SlavedApplicationServiceStore,
|
||||
SlavedRegistrationStore,
|
||||
SlavedFilteringStore,
|
||||
SlavedPresenceStore,
|
||||
BaseSlavedStore,
|
||||
ClientIpStore, # After BaseSlavedStore because the constructor is different
|
||||
):
|
||||
# XXX: This is a bit broken because we don't persist forgotten rooms
|
||||
# in a way that they can be streamed. This means that we don't have a
|
||||
# way to invalidate the forgotten rooms cache correctly.
|
||||
# For now we expire the cache every 10 minutes.
|
||||
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
|
||||
who_forgot_in_room = (
|
||||
RoomMemberStore.__dict__["who_forgot_in_room"]
|
||||
)
|
||||
|
||||
# XXX: This is a bit broken because we don't persist the accepted list in a
|
||||
# way that can be replicated. This means that we don't have a way to
|
||||
# invalidate the cache correctly.
|
||||
get_presence_list_accepted = PresenceStore.__dict__[
|
||||
"get_presence_list_accepted"
|
||||
]
|
||||
|
||||
UPDATE_SYNCING_USERS_MS = 10 * 1000
|
||||
|
||||
|
||||
class SynchrotronPresence(object):
|
||||
def __init__(self, hs):
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
self.store = hs.get_datastore()
|
||||
self.user_to_num_current_syncs = {}
|
||||
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
active_presence = self.store.take_presence_startup_info()
|
||||
self.user_to_current_state = {
|
||||
state.user_id: state
|
||||
for state in active_presence
|
||||
}
|
||||
|
||||
self.process_id = random_string(16)
|
||||
logger.info("Presence process_id is %r", self.process_id)
|
||||
|
||||
self._sending_sync = False
|
||||
self._need_to_send_sync = False
|
||||
self.clock.looping_call(
|
||||
self._send_syncing_users_regularly,
|
||||
UPDATE_SYNCING_USERS_MS,
|
||||
)
|
||||
|
||||
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
|
||||
|
||||
def set_state(self, user, state):
|
||||
# TODO Hows this supposed to work?
|
||||
pass
|
||||
|
||||
get_states = PresenceHandler.get_states.__func__
|
||||
current_state_for_users = PresenceHandler.current_state_for_users.__func__
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_syncing(self, user_id, affect_presence):
|
||||
if affect_presence:
|
||||
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
|
||||
self.user_to_num_current_syncs[user_id] = curr_sync + 1
|
||||
prev_states = yield self.current_state_for_users([user_id])
|
||||
if prev_states[user_id].state == PresenceState.OFFLINE:
|
||||
# TODO: Don't block the sync request on this HTTP hit.
|
||||
yield self._send_syncing_users_now()
|
||||
|
||||
def _end():
|
||||
# We check that the user_id is in user_to_num_current_syncs because
|
||||
# user_to_num_current_syncs may have been cleared if we are
|
||||
# shutting down.
|
||||
if affect_presence and user_id in self.user_to_num_current_syncs:
|
||||
self.user_to_num_current_syncs[user_id] -= 1
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _user_syncing():
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_end()
|
||||
|
||||
defer.returnValue(_user_syncing())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _on_shutdown(self):
|
||||
# When the synchrotron is shutdown tell the master to clear the in
|
||||
# progress syncs for this process
|
||||
self.user_to_num_current_syncs.clear()
|
||||
yield self._send_syncing_users_now()
|
||||
|
||||
def _send_syncing_users_regularly(self):
|
||||
# Only send an update if we aren't in the middle of sending one.
|
||||
if not self._sending_sync:
|
||||
preserve_fn(self._send_syncing_users_now)()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _send_syncing_users_now(self):
|
||||
if self._sending_sync:
|
||||
# We don't want to race with sending another update.
|
||||
# Instead we wait for that update to finish and send another
|
||||
# update afterwards.
|
||||
self._need_to_send_sync = True
|
||||
return
|
||||
|
||||
# Flag that we are sending an update.
|
||||
self._sending_sync = True
|
||||
|
||||
yield self.http_client.post_json_get_json(self.syncing_users_url, {
|
||||
"process_id": self.process_id,
|
||||
"syncing_users": [
|
||||
user_id for user_id, count in self.user_to_num_current_syncs.items()
|
||||
if count > 0
|
||||
],
|
||||
})
|
||||
|
||||
# Unset the flag as we are no longer sending an update.
|
||||
self._sending_sync = False
|
||||
if self._need_to_send_sync:
|
||||
# If something happened while we were sending the update then
|
||||
# we might need to send another update.
|
||||
# TODO: Check if the update that was sent matches the current state
|
||||
# as we only need to send an update if they are different.
|
||||
self._need_to_send_sync = False
|
||||
yield self._send_syncing_users_now()
|
||||
|
||||
def process_replication(self, result):
|
||||
stream = result.get("presence", {"rows": []})
|
||||
for row in stream["rows"]:
|
||||
(
|
||||
position, user_id, state, last_active_ts,
|
||||
last_federation_update_ts, last_user_sync_ts, status_msg,
|
||||
currently_active
|
||||
) = row
|
||||
self.user_to_current_state[user_id] = UserPresenceState(
|
||||
user_id, state, last_active_ts,
|
||||
last_federation_update_ts, last_user_sync_ts, status_msg,
|
||||
currently_active
|
||||
)
|
||||
|
||||
|
||||
class SynchrotronTyping(object):
|
||||
def __init__(self, hs):
|
||||
self._latest_room_serial = 0
|
||||
self._room_serials = {}
|
||||
self._room_typing = {}
|
||||
|
||||
def stream_positions(self):
|
||||
return {"typing": self._latest_room_serial}
|
||||
|
||||
def process_replication(self, result):
|
||||
stream = result.get("typing")
|
||||
if stream:
|
||||
self._latest_room_serial = int(stream["position"])
|
||||
|
||||
for row in stream["rows"]:
|
||||
position, room_id, typing_json = row
|
||||
typing = json.loads(typing_json)
|
||||
self._room_serials[room_id] = position
|
||||
self._room_typing[room_id] = typing
|
||||
|
||||
|
||||
class SynchrotronApplicationService(object):
|
||||
def notify_interested_services(self, event):
|
||||
pass
|
||||
|
||||
|
||||
class SynchrotronServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
|
||||
logger.info("Finished setting up.")
|
||||
|
||||
def _listen_http(self, listener_config):
|
||||
port = listener_config["port"]
|
||||
bind_address = listener_config.get("bind_address", "")
|
||||
site_tag = listener_config.get("tag", port)
|
||||
resources = {}
|
||||
for res in listener_config["resources"]:
|
||||
for name in res["names"]:
|
||||
if name == "metrics":
|
||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||
elif name == "client":
|
||||
resource = JsonResource(self, canonical_json=False)
|
||||
sync.register_servlets(self, resource)
|
||||
resources.update({
|
||||
"/_matrix/client/r0": resource,
|
||||
"/_matrix/client/unstable": resource,
|
||||
"/_matrix/client/v2_alpha": resource,
|
||||
})
|
||||
|
||||
root_resource = create_resource_tree(resources, Resource())
|
||||
reactor.listenTCP(
|
||||
port,
|
||||
SynapseSite(
|
||||
"synapse.access.http.%s" % (site_tag,),
|
||||
site_tag,
|
||||
listener_config,
|
||||
root_resource,
|
||||
),
|
||||
interface=bind_address
|
||||
)
|
||||
logger.info("Synapse synchrotron now listening on port %d", port)
|
||||
|
||||
def start_listening(self, listeners):
|
||||
for listener in listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listen_http(listener)
|
||||
elif listener["type"] == "manhole":
|
||||
reactor.listenTCP(
|
||||
listener["port"],
|
||||
manhole(
|
||||
username="matrix",
|
||||
password="rabbithole",
|
||||
globals={"hs": self},
|
||||
),
|
||||
interface=listener.get("bind_address", '127.0.0.1')
|
||||
)
|
||||
else:
|
||||
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def replicate(self):
|
||||
http_client = self.get_simple_http_client()
|
||||
store = self.get_datastore()
|
||||
replication_url = self.config.worker_replication_url
|
||||
clock = self.get_clock()
|
||||
notifier = self.get_notifier()
|
||||
presence_handler = self.get_presence_handler()
|
||||
typing_handler = self.get_typing_handler()
|
||||
|
||||
def expire_broken_caches():
|
||||
store.who_forgot_in_room.invalidate_all()
|
||||
store.get_presence_list_accepted.invalidate_all()
|
||||
|
||||
def notify_from_stream(
|
||||
result, stream_name, stream_key, room=None, user=None
|
||||
):
|
||||
stream = result.get(stream_name)
|
||||
if stream:
|
||||
position_index = stream["field_names"].index("position")
|
||||
if room:
|
||||
room_index = stream["field_names"].index(room)
|
||||
if user:
|
||||
user_index = stream["field_names"].index(user)
|
||||
|
||||
users = ()
|
||||
rooms = ()
|
||||
for row in stream["rows"]:
|
||||
position = row[position_index]
|
||||
|
||||
if user:
|
||||
users = (row[user_index],)
|
||||
|
||||
if room:
|
||||
rooms = (row[room_index],)
|
||||
|
||||
notifier.on_new_event(
|
||||
stream_key, position, users=users, rooms=rooms
|
||||
)
|
||||
|
||||
def notify(result):
|
||||
stream = result.get("events")
|
||||
if stream:
|
||||
max_position = stream["position"]
|
||||
for row in stream["rows"]:
|
||||
position = row[0]
|
||||
internal = json.loads(row[1])
|
||||
event_json = json.loads(row[2])
|
||||
event = FrozenEvent(event_json, internal_metadata_dict=internal)
|
||||
extra_users = ()
|
||||
if event.type == EventTypes.Member:
|
||||
extra_users = (event.state_key,)
|
||||
notifier.on_new_room_event(
|
||||
event, position, max_position, extra_users
|
||||
)
|
||||
|
||||
notify_from_stream(
|
||||
result, "push_rules", "push_rules_key", user="user_id"
|
||||
)
|
||||
notify_from_stream(
|
||||
result, "user_account_data", "account_data_key", user="user_id"
|
||||
)
|
||||
notify_from_stream(
|
||||
result, "room_account_data", "account_data_key", user="user_id"
|
||||
)
|
||||
notify_from_stream(
|
||||
result, "tag_account_data", "account_data_key", user="user_id"
|
||||
)
|
||||
notify_from_stream(
|
||||
result, "receipts", "receipt_key", room="room_id"
|
||||
)
|
||||
notify_from_stream(
|
||||
result, "typing", "typing_key", room="room_id"
|
||||
)
|
||||
|
||||
next_expire_broken_caches_ms = 0
|
||||
while True:
|
||||
try:
|
||||
args = store.stream_positions()
|
||||
args.update(typing_handler.stream_positions())
|
||||
args["timeout"] = 30000
|
||||
result = yield http_client.get_json(replication_url, args=args)
|
||||
now_ms = clock.time_msec()
|
||||
if now_ms > next_expire_broken_caches_ms:
|
||||
expire_broken_caches()
|
||||
next_expire_broken_caches_ms = (
|
||||
now_ms + store.BROKEN_CACHE_EXPIRY_MS
|
||||
)
|
||||
yield store.process_replication(result)
|
||||
typing_handler.process_replication(result)
|
||||
presence_handler.process_replication(result)
|
||||
notify(result)
|
||||
except:
|
||||
logger.exception("Error replicating from %r", replication_url)
|
||||
yield sleep(5)
|
||||
|
||||
def build_presence_handler(self):
|
||||
return SynchrotronPresence(self)
|
||||
|
||||
def build_typing_handler(self):
|
||||
return SynchrotronTyping(self)
|
||||
|
||||
|
||||
def start(config_options):
|
||||
try:
|
||||
config = HomeServerConfig.load_config(
|
||||
"Synapse synchrotron", config_options
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
assert config.worker_app == "synapse.app.synchrotron"
|
||||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
ss = SynchrotronServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
application_service_handler=SynchrotronApplicationService(),
|
||||
)
|
||||
|
||||
ss.setup()
|
||||
ss.start_listening(config.worker_listeners)
|
||||
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
gc.set_threshold(*config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
def start():
|
||||
ss.get_datastore().start_profiling()
|
||||
ss.replicate()
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
if config.worker_daemonize:
|
||||
daemon = Daemonize(
|
||||
app="synapse-synchrotron",
|
||||
pid=config.worker_pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
daemon.start()
|
||||
else:
|
||||
run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
|
@ -14,11 +14,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import collections
|
||||
import glob
|
||||
import os
|
||||
import os.path
|
||||
import subprocess
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import yaml
|
||||
|
||||
SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
|
||||
|
@ -28,60 +31,181 @@ RED = "\x1b[1;31m"
|
|||
NORMAL = "\x1b[m"
|
||||
|
||||
|
||||
def write(message, colour=NORMAL, stream=sys.stdout):
|
||||
if colour == NORMAL:
|
||||
stream.write(message + "\n")
|
||||
else:
|
||||
stream.write(colour + message + NORMAL + "\n")
|
||||
|
||||
|
||||
def start(configfile):
|
||||
print ("Starting ...")
|
||||
write("Starting ...")
|
||||
args = SYNAPSE
|
||||
args.extend(["--daemonize", "-c", configfile])
|
||||
|
||||
try:
|
||||
subprocess.check_call(args)
|
||||
print (GREEN + "started" + NORMAL)
|
||||
write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print (
|
||||
RED +
|
||||
"error starting (exit code: %d); see above for logs" % e.returncode +
|
||||
NORMAL
|
||||
write(
|
||||
"error starting (exit code: %d); see above for logs" % e.returncode,
|
||||
colour=RED,
|
||||
)
|
||||
|
||||
|
||||
def stop(pidfile):
|
||||
def start_worker(app, configfile, worker_configfile):
|
||||
args = [
|
||||
"python", "-B",
|
||||
"-m", app,
|
||||
"-c", configfile,
|
||||
"-c", worker_configfile
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.check_call(args)
|
||||
write("started %s(%r)" % (app, worker_configfile), colour=GREEN)
|
||||
except subprocess.CalledProcessError as e:
|
||||
write(
|
||||
"error starting %s(%r) (exit code: %d); see above for logs" % (
|
||||
app, worker_configfile, e.returncode,
|
||||
),
|
||||
colour=RED,
|
||||
)
|
||||
|
||||
|
||||
def stop(pidfile, app):
|
||||
if os.path.exists(pidfile):
|
||||
pid = int(open(pidfile).read())
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
print (GREEN + "stopped" + NORMAL)
|
||||
write("stopped %s" % (app,), colour=GREEN)
|
||||
|
||||
|
||||
Worker = collections.namedtuple("Worker", [
|
||||
"app", "configfile", "pidfile", "cache_factor"
|
||||
])
|
||||
|
||||
|
||||
def main():
|
||||
configfile = sys.argv[2] if len(sys.argv) == 3 else "homeserver.yaml"
|
||||
|
||||
if not os.path.exists(configfile):
|
||||
sys.stderr.write(
|
||||
"No config file found\n"
|
||||
"To generate a config file, run '%s -c %s --generate-config"
|
||||
" --server-name=<server name>'\n" % (
|
||||
" ".join(SYNAPSE), configfile
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"action",
|
||||
choices=["start", "stop", "restart"],
|
||||
help="whether to start, stop or restart the synapse",
|
||||
)
|
||||
parser.add_argument(
|
||||
"configfile",
|
||||
nargs="?",
|
||||
default="homeserver.yaml",
|
||||
help="the homeserver config file, defaults to homserver.yaml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-w", "--worker",
|
||||
metavar="WORKERCONFIG",
|
||||
help="start or stop a single worker",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a", "--all-processes",
|
||||
metavar="WORKERCONFIGDIR",
|
||||
help="start or stop all the workers in the given directory"
|
||||
" and the main synapse process",
|
||||
)
|
||||
|
||||
options = parser.parse_args()
|
||||
|
||||
if options.worker and options.all_processes:
|
||||
write(
|
||||
'Cannot use "--worker" with "--all-processes"',
|
||||
stream=sys.stderr
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
config = yaml.load(open(configfile))
|
||||
configfile = options.configfile
|
||||
|
||||
if not os.path.exists(configfile):
|
||||
write(
|
||||
"No config file found\n"
|
||||
"To generate a config file, run '%s -c %s --generate-config"
|
||||
" --server-name=<server name>'\n" % (
|
||||
" ".join(SYNAPSE), options.configfile
|
||||
),
|
||||
stream=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
with open(configfile) as stream:
|
||||
config = yaml.load(stream)
|
||||
|
||||
pidfile = config["pid_file"]
|
||||
cache_factor = config.get("synctl_cache_factor", None)
|
||||
cache_factor = config.get("synctl_cache_factor")
|
||||
start_stop_synapse = True
|
||||
|
||||
if cache_factor:
|
||||
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
|
||||
|
||||
action = sys.argv[1] if sys.argv[1:] else "usage"
|
||||
if action == "start":
|
||||
start(configfile)
|
||||
elif action == "stop":
|
||||
stop(pidfile)
|
||||
elif action == "restart":
|
||||
stop(pidfile)
|
||||
start(configfile)
|
||||
else:
|
||||
sys.stderr.write("Usage: %s [start|stop|restart] [configfile]\n" % (sys.argv[0],))
|
||||
worker_configfiles = []
|
||||
if options.worker:
|
||||
start_stop_synapse = False
|
||||
worker_configfile = options.worker
|
||||
if not os.path.exists(worker_configfile):
|
||||
write(
|
||||
"No worker config found at %r" % (worker_configfile,),
|
||||
stream=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
worker_configfiles.append(worker_configfile)
|
||||
|
||||
if options.all_processes:
|
||||
worker_configdir = options.all_processes
|
||||
if not os.path.isdir(worker_configdir):
|
||||
write(
|
||||
"No worker config directory found at %r" % (worker_configdir,),
|
||||
stream=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
worker_configfiles.extend(sorted(glob.glob(
|
||||
os.path.join(worker_configdir, "*.yaml")
|
||||
)))
|
||||
|
||||
workers = []
|
||||
for worker_configfile in worker_configfiles:
|
||||
with open(worker_configfile) as stream:
|
||||
worker_config = yaml.load(stream)
|
||||
worker_app = worker_config["worker_app"]
|
||||
worker_pidfile = worker_config["worker_pid_file"]
|
||||
worker_daemonize = worker_config["worker_daemonize"]
|
||||
assert worker_daemonize # TODO print something more user friendly
|
||||
worker_cache_factor = worker_config.get("synctl_cache_factor")
|
||||
workers.append(Worker(
|
||||
worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
|
||||
))
|
||||
|
||||
action = options.action
|
||||
|
||||
if action == "stop" or action == "restart":
|
||||
for worker in workers:
|
||||
stop(worker.pidfile, worker.app)
|
||||
|
||||
if start_stop_synapse:
|
||||
stop(pidfile, "synapse.app.homeserver")
|
||||
|
||||
# TODO: Wait for synapse to actually shutdown before starting it again
|
||||
|
||||
if action == "start" or action == "restart":
|
||||
if start_stop_synapse:
|
||||
start(configfile)
|
||||
|
||||
for worker in workers:
|
||||
if worker.cache_factor:
|
||||
os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor)
|
||||
|
||||
start_worker(worker.app, configfile, worker.configfile)
|
||||
|
||||
if cache_factor:
|
||||
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
|
||||
else:
|
||||
os.environ.pop("SYNAPSE_CACHE_FACTOR", None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -56,22 +56,22 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppServiceScheduler(object):
|
||||
class ApplicationServiceScheduler(object):
|
||||
""" Public facing API for this module. Does the required DI to tie the
|
||||
components together. This also serves as the "event_pool", which in this
|
||||
case is a simple array.
|
||||
"""
|
||||
|
||||
def __init__(self, clock, store, as_api):
|
||||
self.clock = clock
|
||||
self.store = store
|
||||
self.as_api = as_api
|
||||
def __init__(self, hs):
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.as_api = hs.get_application_service_api()
|
||||
|
||||
def create_recoverer(service, callback):
|
||||
return _Recoverer(clock, store, as_api, service, callback)
|
||||
return _Recoverer(self.clock, self.store, self.as_api, service, callback)
|
||||
|
||||
self.txn_ctrl = _TransactionController(
|
||||
clock, store, as_api, create_recoverer
|
||||
self.clock, self.store, self.as_api, create_recoverer
|
||||
)
|
||||
self.queuer = _ServiceQueuer(self.txn_ctrl)
|
||||
|
||||
|
|
|
@ -157,9 +157,40 @@ class Config(object):
|
|||
return default_config, config
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, description, argv, generate_section=None):
|
||||
obj = cls()
|
||||
def load_config(cls, description, argv):
|
||||
config_parser = argparse.ArgumentParser(
|
||||
description=description,
|
||||
)
|
||||
config_parser.add_argument(
|
||||
"-c", "--config-path",
|
||||
action="append",
|
||||
metavar="CONFIG_FILE",
|
||||
help="Specify config file. Can be given multiple times and"
|
||||
" may specify directories containing *.yaml files."
|
||||
)
|
||||
|
||||
config_parser.add_argument(
|
||||
"--keys-directory",
|
||||
metavar="DIRECTORY",
|
||||
help="Where files such as certs and signing keys are stored when"
|
||||
" their location is given explicitly in the config."
|
||||
" Defaults to the directory containing the last config file",
|
||||
)
|
||||
|
||||
config_args = config_parser.parse_args(argv)
|
||||
|
||||
config_files = find_config_files(search_paths=config_args.config_path)
|
||||
|
||||
obj = cls()
|
||||
obj.read_config_files(
|
||||
config_files,
|
||||
keys_directory=config_args.keys_directory,
|
||||
generate_keys=False,
|
||||
)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def load_or_generate_config(cls, description, argv):
|
||||
config_parser = argparse.ArgumentParser(add_help=False)
|
||||
config_parser.add_argument(
|
||||
"-c", "--config-path",
|
||||
|
@ -176,7 +207,7 @@ class Config(object):
|
|||
config_parser.add_argument(
|
||||
"--report-stats",
|
||||
action="store",
|
||||
help="Stuff",
|
||||
help="Whether the generated config reports anonymized usage statistics",
|
||||
choices=["yes", "no"]
|
||||
)
|
||||
config_parser.add_argument(
|
||||
|
@ -197,36 +228,11 @@ class Config(object):
|
|||
)
|
||||
config_args, remaining_args = config_parser.parse_known_args(argv)
|
||||
|
||||
config_files = find_config_files(search_paths=config_args.config_path)
|
||||
|
||||
generate_keys = config_args.generate_keys
|
||||
|
||||
config_files = []
|
||||
if config_args.config_path:
|
||||
for config_path in config_args.config_path:
|
||||
if os.path.isdir(config_path):
|
||||
# We accept specifying directories as config paths, we search
|
||||
# inside that directory for all files matching *.yaml, and then
|
||||
# we apply them in *sorted* order.
|
||||
files = []
|
||||
for entry in os.listdir(config_path):
|
||||
entry_path = os.path.join(config_path, entry)
|
||||
if not os.path.isfile(entry_path):
|
||||
print (
|
||||
"Found subdirectory in config directory: %r. IGNORING."
|
||||
) % (entry_path, )
|
||||
continue
|
||||
|
||||
if not entry.endswith(".yaml"):
|
||||
print (
|
||||
"Found file in config directory that does not"
|
||||
" end in '.yaml': %r. IGNORING."
|
||||
) % (entry_path, )
|
||||
continue
|
||||
|
||||
files.append(entry_path)
|
||||
|
||||
config_files.extend(sorted(files))
|
||||
else:
|
||||
config_files.append(config_path)
|
||||
obj = cls()
|
||||
|
||||
if config_args.generate_config:
|
||||
if config_args.report_stats is None:
|
||||
|
@ -299,28 +305,43 @@ class Config(object):
|
|||
" -c CONFIG-FILE\""
|
||||
)
|
||||
|
||||
if config_args.keys_directory:
|
||||
config_dir_path = config_args.keys_directory
|
||||
else:
|
||||
config_dir_path = os.path.dirname(config_args.config_path[-1])
|
||||
config_dir_path = os.path.abspath(config_dir_path)
|
||||
obj.read_config_files(
|
||||
config_files,
|
||||
keys_directory=config_args.keys_directory,
|
||||
generate_keys=generate_keys,
|
||||
)
|
||||
|
||||
if generate_keys:
|
||||
return None
|
||||
|
||||
obj.invoke_all("read_arguments", args)
|
||||
|
||||
return obj
|
||||
|
||||
def read_config_files(self, config_files, keys_directory=None,
|
||||
generate_keys=False):
|
||||
if not keys_directory:
|
||||
keys_directory = os.path.dirname(config_files[-1])
|
||||
|
||||
config_dir_path = os.path.abspath(keys_directory)
|
||||
|
||||
specified_config = {}
|
||||
for config_file in config_files:
|
||||
yaml_config = cls.read_config_file(config_file)
|
||||
yaml_config = self.read_config_file(config_file)
|
||||
specified_config.update(yaml_config)
|
||||
|
||||
if "server_name" not in specified_config:
|
||||
raise ConfigError(MISSING_SERVER_NAME)
|
||||
|
||||
server_name = specified_config["server_name"]
|
||||
_, config = obj.generate_config(
|
||||
_, config = self.generate_config(
|
||||
config_dir_path=config_dir_path,
|
||||
server_name=server_name,
|
||||
is_generating_file=False,
|
||||
)
|
||||
config.pop("log_config")
|
||||
config.update(specified_config)
|
||||
|
||||
if "report_stats" not in config:
|
||||
raise ConfigError(
|
||||
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
|
||||
|
@ -328,11 +349,51 @@ class Config(object):
|
|||
)
|
||||
|
||||
if generate_keys:
|
||||
obj.invoke_all("generate_files", config)
|
||||
self.invoke_all("generate_files", config)
|
||||
return
|
||||
|
||||
obj.invoke_all("read_config", config)
|
||||
self.invoke_all("read_config", config)
|
||||
|
||||
obj.invoke_all("read_arguments", args)
|
||||
|
||||
return obj
|
||||
def find_config_files(search_paths):
|
||||
"""Finds config files using a list of search paths. If a path is a file
|
||||
then that file path is added to the list. If a search path is a directory
|
||||
then all the "*.yaml" files in that directory are added to the list in
|
||||
sorted order.
|
||||
|
||||
Args:
|
||||
search_paths(list(str)): A list of paths to search.
|
||||
|
||||
Returns:
|
||||
list(str): A list of file paths.
|
||||
"""
|
||||
|
||||
config_files = []
|
||||
if search_paths:
|
||||
for config_path in search_paths:
|
||||
if os.path.isdir(config_path):
|
||||
# We accept specifying directories as config paths, we search
|
||||
# inside that directory for all files matching *.yaml, and then
|
||||
# we apply them in *sorted* order.
|
||||
files = []
|
||||
for entry in os.listdir(config_path):
|
||||
entry_path = os.path.join(config_path, entry)
|
||||
if not os.path.isfile(entry_path):
|
||||
print (
|
||||
"Found subdirectory in config directory: %r. IGNORING."
|
||||
) % (entry_path, )
|
||||
continue
|
||||
|
||||
if not entry.endswith(".yaml"):
|
||||
print (
|
||||
"Found file in config directory that does not"
|
||||
" end in '.yaml': %r. IGNORING."
|
||||
) % (entry_path, )
|
||||
continue
|
||||
|
||||
files.append(entry_path)
|
||||
|
||||
config_files.extend(sorted(files))
|
||||
else:
|
||||
config_files.append(config_path)
|
||||
return config_files
|
||||
|
|
|
@ -27,6 +27,7 @@ class CaptchaConfig(Config):
|
|||
def default_config(self, **kwargs):
|
||||
return """\
|
||||
## Captcha ##
|
||||
# See docs/CAPTCHA_SETUP for full details of configuring this.
|
||||
|
||||
# This Home Server's ReCAPTCHA public key.
|
||||
recaptcha_public_key: "YOUR_PUBLIC_KEY"
|
||||
|
|
|
@ -89,7 +89,7 @@ class EmailConfig(Config):
|
|||
# enable_notifs: false
|
||||
# smtp_host: "localhost"
|
||||
# smtp_port: 25
|
||||
# notif_from: Your Friendly Matrix Home Server <noreply@example.com>
|
||||
# notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>"
|
||||
# app_name: Matrix
|
||||
# template_dir: res/templates
|
||||
# notif_template_html: notif_mail.html
|
||||
|
|
|
@ -32,13 +32,15 @@ from .password import PasswordConfig
|
|||
from .jwt import JWTConfig
|
||||
from .ldap import LDAPConfig
|
||||
from .emailconfig import EmailConfig
|
||||
from .workers import WorkerConfig
|
||||
|
||||
|
||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
|
||||
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
|
||||
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
|
||||
JWTConfig, LDAPConfig, PasswordConfig, EmailConfig,):
|
||||
JWTConfig, LDAPConfig, PasswordConfig, EmailConfig,
|
||||
WorkerConfig,):
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
@ -13,40 +13,88 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
|
||||
MISSING_LDAP3 = (
|
||||
"Missing ldap3 library. This is required for LDAP Authentication."
|
||||
)
|
||||
|
||||
|
||||
class LDAPMode(object):
|
||||
SIMPLE = "simple",
|
||||
SEARCH = "search",
|
||||
|
||||
LIST = (SIMPLE, SEARCH)
|
||||
|
||||
|
||||
class LDAPConfig(Config):
|
||||
def read_config(self, config):
|
||||
ldap_config = config.get("ldap_config", None)
|
||||
if ldap_config:
|
||||
ldap_config = config.get("ldap_config", {})
|
||||
|
||||
self.ldap_enabled = ldap_config.get("enabled", False)
|
||||
self.ldap_server = ldap_config["server"]
|
||||
self.ldap_port = ldap_config["port"]
|
||||
self.ldap_tls = ldap_config.get("tls", False)
|
||||
self.ldap_search_base = ldap_config["search_base"]
|
||||
self.ldap_search_property = ldap_config["search_property"]
|
||||
self.ldap_email_property = ldap_config["email_property"]
|
||||
self.ldap_full_name_property = ldap_config["full_name_property"]
|
||||
else:
|
||||
self.ldap_enabled = False
|
||||
self.ldap_server = None
|
||||
self.ldap_port = None
|
||||
self.ldap_tls = False
|
||||
self.ldap_search_base = None
|
||||
self.ldap_search_property = None
|
||||
self.ldap_email_property = None
|
||||
self.ldap_full_name_property = None
|
||||
|
||||
if self.ldap_enabled:
|
||||
# verify dependencies are available
|
||||
try:
|
||||
import ldap3
|
||||
ldap3 # to stop unused lint
|
||||
except ImportError:
|
||||
raise ConfigError(MISSING_LDAP3)
|
||||
|
||||
self.ldap_mode = LDAPMode.SIMPLE
|
||||
|
||||
# verify config sanity
|
||||
self.require_keys(ldap_config, [
|
||||
"uri",
|
||||
"base",
|
||||
"attributes",
|
||||
])
|
||||
|
||||
self.ldap_uri = ldap_config["uri"]
|
||||
self.ldap_start_tls = ldap_config.get("start_tls", False)
|
||||
self.ldap_base = ldap_config["base"]
|
||||
self.ldap_attributes = ldap_config["attributes"]
|
||||
|
||||
if "bind_dn" in ldap_config:
|
||||
self.ldap_mode = LDAPMode.SEARCH
|
||||
self.require_keys(ldap_config, [
|
||||
"bind_dn",
|
||||
"bind_password",
|
||||
])
|
||||
|
||||
self.ldap_bind_dn = ldap_config["bind_dn"]
|
||||
self.ldap_bind_password = ldap_config["bind_password"]
|
||||
self.ldap_filter = ldap_config.get("filter", None)
|
||||
|
||||
# verify attribute lookup
|
||||
self.require_keys(ldap_config['attributes'], [
|
||||
"uid",
|
||||
"name",
|
||||
"mail",
|
||||
])
|
||||
|
||||
def require_keys(self, config, required):
|
||||
missing = [key for key in required if key not in config]
|
||||
if missing:
|
||||
raise ConfigError(
|
||||
"LDAP enabled but missing required config values: {}".format(
|
||||
", ".join(missing)
|
||||
)
|
||||
)
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
return """\
|
||||
# ldap_config:
|
||||
# enabled: true
|
||||
# server: "ldap://localhost"
|
||||
# port: 389
|
||||
# tls: false
|
||||
# search_base: "ou=Users,dc=example,dc=com"
|
||||
# search_property: "cn"
|
||||
# email_property: "email"
|
||||
# full_name_property: "givenName"
|
||||
# uri: "ldap://ldap.example.com:389"
|
||||
# start_tls: true
|
||||
# base: "ou=users,dc=example,dc=com"
|
||||
# attributes:
|
||||
# uid: "cn"
|
||||
# mail: "email"
|
||||
# name: "givenName"
|
||||
# #bind_dn:
|
||||
# #bind_password:
|
||||
# #filter: "(objectClass=posixAccount)"
|
||||
"""
|
||||
|
|
|
@ -126,17 +126,21 @@ class LoggingConfig(Config):
|
|||
)
|
||||
|
||||
def setup_logging(self):
|
||||
setup_logging(self.log_config, self.log_file, self.verbosity)
|
||||
|
||||
|
||||
def setup_logging(log_config=None, log_file=None, verbosity=None):
|
||||
log_format = (
|
||||
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
|
||||
" - %(message)s"
|
||||
)
|
||||
if self.log_config is None:
|
||||
if log_config is None:
|
||||
|
||||
level = logging.INFO
|
||||
level_for_storage = logging.INFO
|
||||
if self.verbosity:
|
||||
if verbosity:
|
||||
level = logging.DEBUG
|
||||
if self.verbosity > 1:
|
||||
if verbosity > 1:
|
||||
level_for_storage = logging.DEBUG
|
||||
|
||||
# FIXME: we need a logging.WARN for a -q quiet option
|
||||
|
@ -146,10 +150,10 @@ class LoggingConfig(Config):
|
|||
logging.getLogger('synapse.storage').setLevel(level_for_storage)
|
||||
|
||||
formatter = logging.Formatter(log_format)
|
||||
if self.log_file:
|
||||
if log_file:
|
||||
# TODO: Customisable file size / backup count
|
||||
handler = logging.handlers.RotatingFileHandler(
|
||||
self.log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
|
||||
log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
|
||||
)
|
||||
|
||||
def sighup(signum, stack):
|
||||
|
@ -172,7 +176,7 @@ class LoggingConfig(Config):
|
|||
|
||||
logger.addHandler(handler)
|
||||
else:
|
||||
with open(self.log_config, 'r') as f:
|
||||
with open(log_config, 'r') as f:
|
||||
logging.config.dictConfig(yaml.load(f))
|
||||
|
||||
observer = PythonLoggingObserver()
|
||||
|
|
|
@ -23,10 +23,14 @@ class PasswordConfig(Config):
|
|||
def read_config(self, config):
|
||||
password_config = config.get("password_config", {})
|
||||
self.password_enabled = password_config.get("enabled", True)
|
||||
self.password_pepper = password_config.get("pepper", "")
|
||||
|
||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||
return """
|
||||
# Enable password for login.
|
||||
password_config:
|
||||
enabled: true
|
||||
# Uncomment and change to a secret random string for extra security.
|
||||
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
|
||||
#pepper: ""
|
||||
"""
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
|
||||
class ServerConfig(Config):
|
||||
|
@ -27,8 +27,9 @@ class ServerConfig(Config):
|
|||
self.daemonize = config.get("daemonize")
|
||||
self.print_pidfile = config.get("print_pidfile")
|
||||
self.user_agent_suffix = config.get("user_agent_suffix")
|
||||
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
|
||||
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
|
||||
self.public_baseurl = config.get("public_baseurl")
|
||||
self.secondary_directory_servers = config.get("secondary_directory_servers", [])
|
||||
|
||||
if self.public_baseurl is not None:
|
||||
if self.public_baseurl[-1] != '/':
|
||||
|
@ -37,6 +38,8 @@ class ServerConfig(Config):
|
|||
|
||||
self.listeners = config.get("listeners", [])
|
||||
|
||||
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
|
||||
|
||||
bind_port = config.get("bind_port")
|
||||
if bind_port:
|
||||
self.listeners = []
|
||||
|
@ -104,26 +107,6 @@ class ServerConfig(Config):
|
|||
]
|
||||
})
|
||||
|
||||
# Attempt to guess the content_addr for the v0 content repostitory
|
||||
content_addr = config.get("content_addr")
|
||||
if not content_addr:
|
||||
for listener in self.listeners:
|
||||
if listener["type"] == "http" and not listener.get("tls", False):
|
||||
unsecure_port = listener["port"]
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("Could not determine 'content_addr'")
|
||||
|
||||
host = self.server_name
|
||||
if ':' not in host:
|
||||
host = "%s:%d" % (host, unsecure_port)
|
||||
else:
|
||||
host = host.split(':')[0]
|
||||
host = "%s:%d" % (host, unsecure_port)
|
||||
content_addr = "http://%s" % (host,)
|
||||
|
||||
self.content_addr = content_addr
|
||||
|
||||
def default_config(self, server_name, **kwargs):
|
||||
if ":" in server_name:
|
||||
bind_port = int(server_name.split(":")[1])
|
||||
|
@ -156,6 +139,17 @@ class ServerConfig(Config):
|
|||
# hard limit.
|
||||
soft_file_limit: 0
|
||||
|
||||
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
|
||||
# gc_thresholds: [700, 10, 10]
|
||||
|
||||
# A list of other Home Servers to fetch the public room directory from
|
||||
# and include in the public room directory of this home server
|
||||
# This is a temporary stopgap solution to populate new server with a
|
||||
# list of rooms until there exists a good solution of a decentralized
|
||||
# room directory.
|
||||
# secondary_directory_servers:
|
||||
# - matrix.org
|
||||
|
||||
# List of ports that Synapse should listen on, their purpose and their
|
||||
# configuration.
|
||||
listeners:
|
||||
|
@ -237,3 +231,20 @@ class ServerConfig(Config):
|
|||
type=int,
|
||||
help="Turn on the twisted telnet manhole"
|
||||
" service on the given port.")
|
||||
|
||||
|
||||
def read_gc_thresholds(thresholds):
|
||||
"""Reads the three integer thresholds for garbage collection. Ensures that
|
||||
the thresholds are integers if thresholds are supplied.
|
||||
"""
|
||||
if thresholds is None:
|
||||
return None
|
||||
try:
|
||||
assert len(thresholds) == 3
|
||||
return (
|
||||
int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
|
||||
)
|
||||
except:
|
||||
raise ConfigError(
|
||||
"Value of `gc_threshold` must be a list of three integers if set"
|
||||
)
|
||||
|
|
31
synapse/config/workers.py
Normal file
31
synapse/config/workers.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config
|
||||
|
||||
|
||||
class WorkerConfig(Config):
|
||||
"""The workers are processes run separately to the main synapse process.
|
||||
They have their own pid_file and listener configuration. They use the
|
||||
replication_url to talk to the main synapse process."""
|
||||
|
||||
def read_config(self, config):
|
||||
self.worker_app = config.get("worker_app")
|
||||
self.worker_listeners = config.get("worker_listeners")
|
||||
self.worker_daemonize = config.get("worker_daemonize")
|
||||
self.worker_pid_file = config.get("worker_pid_file")
|
||||
self.worker_log_file = config.get("worker_log_file")
|
||||
self.worker_log_config = config.get("worker_log_config")
|
||||
self.worker_replication_url = config.get("worker_replication_url")
|
|
@ -77,10 +77,12 @@ class SynapseKeyClientProtocol(HTTPClient):
|
|||
def __init__(self):
|
||||
self.remote_key = defer.Deferred()
|
||||
self.host = None
|
||||
self._peer = None
|
||||
|
||||
def connectionMade(self):
|
||||
self.host = self.transport.getHost()
|
||||
logger.debug("Connected to %s", self.host)
|
||||
self._peer = self.transport.getPeer()
|
||||
logger.debug("Connected to %s", self._peer)
|
||||
|
||||
self.sendCommand(b"GET", self.path)
|
||||
if self.host:
|
||||
self.sendHeader(b"Host", self.host)
|
||||
|
@ -124,7 +126,10 @@ class SynapseKeyClientProtocol(HTTPClient):
|
|||
self.timer.cancel()
|
||||
|
||||
def on_timeout(self):
|
||||
logger.debug("Timeout waiting for response from %s", self.host)
|
||||
logger.debug(
|
||||
"Timeout waiting for response from %s: %s",
|
||||
self.host, self._peer,
|
||||
)
|
||||
self.errback(IOError("Timeout waiting for response"))
|
||||
self.transport.abortConnection()
|
||||
|
||||
|
@ -133,4 +138,5 @@ class SynapseKeyClientFactory(Factory):
|
|||
def protocol(self):
|
||||
protocol = SynapseKeyClientProtocol()
|
||||
protocol.path = self.path
|
||||
protocol.host = self.host
|
||||
return protocol
|
||||
|
|
|
@ -44,7 +44,25 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
|
||||
VerifyKeyRequest = namedtuple("VerifyRequest", (
|
||||
"server_name", "key_ids", "json_object", "deferred"
|
||||
))
|
||||
"""
|
||||
A request for a verify key to verify a JSON object.
|
||||
|
||||
Attributes:
|
||||
server_name(str): The name of the server to verify against.
|
||||
key_ids(set(str)): The set of key_ids to that could be used to verify the
|
||||
JSON object
|
||||
json_object(dict): The JSON object to verify.
|
||||
deferred(twisted.internet.defer.Deferred):
|
||||
A deferred (server_name, key_id, verify_key) tuple that resolves when
|
||||
a verify key has been fetched
|
||||
"""
|
||||
|
||||
|
||||
class KeyLookupError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class Keyring(object):
|
||||
|
@ -74,39 +92,32 @@ class Keyring(object):
|
|||
list of deferreds indicating success or failure to verify each
|
||||
json object's signature for the given server_name.
|
||||
"""
|
||||
group_id_to_json = {}
|
||||
group_id_to_group = {}
|
||||
group_ids = []
|
||||
|
||||
next_group_id = 0
|
||||
deferreds = {}
|
||||
verify_requests = []
|
||||
|
||||
for server_name, json_object in server_and_json:
|
||||
logger.debug("Verifying for %s", server_name)
|
||||
group_id = next_group_id
|
||||
next_group_id += 1
|
||||
group_ids.append(group_id)
|
||||
|
||||
key_ids = signature_ids(json_object, server_name)
|
||||
if not key_ids:
|
||||
deferreds[group_id] = defer.fail(SynapseError(
|
||||
deferred = defer.fail(SynapseError(
|
||||
400,
|
||||
"Not signed with a supported algorithm",
|
||||
Codes.UNAUTHORIZED,
|
||||
))
|
||||
else:
|
||||
deferreds[group_id] = defer.Deferred()
|
||||
deferred = defer.Deferred()
|
||||
|
||||
group = KeyGroup(server_name, group_id, key_ids)
|
||||
verify_request = VerifyKeyRequest(
|
||||
server_name, key_ids, json_object, deferred
|
||||
)
|
||||
|
||||
group_id_to_group[group_id] = group
|
||||
group_id_to_json[group_id] = json_object
|
||||
verify_requests.append(verify_request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_key_deferred(group, deferred):
|
||||
server_name = group.server_name
|
||||
def handle_key_deferred(verify_request):
|
||||
server_name = verify_request.server_name
|
||||
try:
|
||||
_, _, key_id, verify_key = yield deferred
|
||||
_, key_id, verify_key = yield verify_request.deferred
|
||||
except IOError as e:
|
||||
logger.warn(
|
||||
"Got IOError when downloading keys for %s: %s %s",
|
||||
|
@ -128,7 +139,7 @@ class Keyring(object):
|
|||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
json_object = group_id_to_json[group.group_id]
|
||||
json_object = verify_request.json_object
|
||||
|
||||
try:
|
||||
verify_signed_json(json_object, server_name, verify_key)
|
||||
|
@ -157,36 +168,34 @@ class Keyring(object):
|
|||
|
||||
# Actually start fetching keys.
|
||||
wait_on_deferred.addBoth(
|
||||
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
|
||||
lambda _: self.get_server_verify_keys(verify_requests)
|
||||
)
|
||||
|
||||
# When we've finished fetching all the keys for a given server_name,
|
||||
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
||||
# any lookups waiting will proceed.
|
||||
server_to_gids = {}
|
||||
server_to_request_ids = {}
|
||||
|
||||
def remove_deferreds(res, server_name, group_id):
|
||||
server_to_gids[server_name].discard(group_id)
|
||||
if not server_to_gids[server_name]:
|
||||
def remove_deferreds(res, server_name, verify_request):
|
||||
request_id = id(verify_request)
|
||||
server_to_request_ids[server_name].discard(request_id)
|
||||
if not server_to_request_ids[server_name]:
|
||||
d = server_to_deferred.pop(server_name, None)
|
||||
if d:
|
||||
d.callback(None)
|
||||
return res
|
||||
|
||||
for g_id, deferred in deferreds.items():
|
||||
server_name = group_id_to_group[g_id].server_name
|
||||
server_to_gids.setdefault(server_name, set()).add(g_id)
|
||||
deferred.addBoth(remove_deferreds, server_name, g_id)
|
||||
for verify_request in verify_requests:
|
||||
server_name = verify_request.server_name
|
||||
request_id = id(verify_request)
|
||||
server_to_request_ids.setdefault(server_name, set()).add(request_id)
|
||||
deferred.addBoth(remove_deferreds, server_name, verify_request)
|
||||
|
||||
# Pass those keys to handle_key_deferred so that the json object
|
||||
# signatures can be verified
|
||||
return [
|
||||
preserve_context_over_fn(
|
||||
handle_key_deferred,
|
||||
group_id_to_group[g_id],
|
||||
deferreds[g_id],
|
||||
)
|
||||
for g_id in group_ids
|
||||
preserve_context_over_fn(handle_key_deferred, verify_request)
|
||||
for verify_request in verify_requests
|
||||
]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -220,7 +229,7 @@ class Keyring(object):
|
|||
|
||||
d.addBoth(rm, server_name)
|
||||
|
||||
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
|
||||
def get_server_verify_keys(self, verify_requests):
|
||||
"""Takes a dict of KeyGroups and tries to find at least one key for
|
||||
each group.
|
||||
"""
|
||||
|
@ -237,62 +246,64 @@ class Keyring(object):
|
|||
merged_results = {}
|
||||
|
||||
missing_keys = {}
|
||||
for group in group_id_to_group.values():
|
||||
missing_keys.setdefault(group.server_name, set()).update(
|
||||
group.key_ids
|
||||
for verify_request in verify_requests:
|
||||
missing_keys.setdefault(verify_request.server_name, set()).update(
|
||||
verify_request.key_ids
|
||||
)
|
||||
|
||||
for fn in key_fetch_fns:
|
||||
results = yield fn(missing_keys.items())
|
||||
merged_results.update(results)
|
||||
|
||||
# We now need to figure out which groups we have keys for
|
||||
# and which we don't
|
||||
missing_groups = {}
|
||||
for group in group_id_to_group.values():
|
||||
for key_id in group.key_ids:
|
||||
if key_id in merged_results[group.server_name]:
|
||||
# We now need to figure out which verify requests we have keys
|
||||
# for and which we don't
|
||||
missing_keys = {}
|
||||
requests_missing_keys = []
|
||||
for verify_request in verify_requests:
|
||||
server_name = verify_request.server_name
|
||||
result_keys = merged_results[server_name]
|
||||
|
||||
if verify_request.deferred.called:
|
||||
# We've already called this deferred, which probably
|
||||
# means that we've already found a key for it.
|
||||
continue
|
||||
|
||||
for key_id in verify_request.key_ids:
|
||||
if key_id in result_keys:
|
||||
with PreserveLoggingContext():
|
||||
group_id_to_deferred[group.group_id].callback((
|
||||
group.group_id,
|
||||
group.server_name,
|
||||
verify_request.deferred.callback((
|
||||
server_name,
|
||||
key_id,
|
||||
merged_results[group.server_name][key_id],
|
||||
result_keys[key_id],
|
||||
))
|
||||
break
|
||||
else:
|
||||
missing_groups.setdefault(
|
||||
group.server_name, []
|
||||
).append(group)
|
||||
# The else block is only reached if the loop above
|
||||
# doesn't break.
|
||||
missing_keys.setdefault(server_name, set()).update(
|
||||
verify_request.key_ids
|
||||
)
|
||||
requests_missing_keys.append(verify_request)
|
||||
|
||||
if not missing_groups:
|
||||
if not missing_keys:
|
||||
break
|
||||
|
||||
missing_keys = {
|
||||
server_name: set(
|
||||
key_id for group in groups for key_id in group.key_ids
|
||||
)
|
||||
for server_name, groups in missing_groups.items()
|
||||
}
|
||||
|
||||
for group in missing_groups.values():
|
||||
group_id_to_deferred[group.group_id].errback(SynapseError(
|
||||
for verify_request in requests_missing_keys.values():
|
||||
verify_request.deferred.errback(SynapseError(
|
||||
401,
|
||||
"No key for %s with id %s" % (
|
||||
group.server_name, group.key_ids,
|
||||
verify_request.server_name, verify_request.key_ids,
|
||||
),
|
||||
Codes.UNAUTHORIZED,
|
||||
))
|
||||
|
||||
def on_err(err):
|
||||
for deferred in group_id_to_deferred.values():
|
||||
if not deferred.called:
|
||||
deferred.errback(err)
|
||||
for verify_request in verify_requests:
|
||||
if not verify_request.deferred.called:
|
||||
verify_request.deferred.errback(err)
|
||||
|
||||
do_iterations().addErrback(on_err)
|
||||
|
||||
return group_id_to_deferred
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_keys_from_store(self, server_name_and_key_ids):
|
||||
res = yield defer.gatherResults(
|
||||
|
@ -356,7 +367,7 @@ class Keyring(object):
|
|||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"Unable to getting key %r for %r directly: %s %s",
|
||||
"Unable to get key %r for %r directly: %s %s",
|
||||
key_ids, server_name,
|
||||
type(e).__name__, str(e.message),
|
||||
)
|
||||
|
@ -418,7 +429,7 @@ class Keyring(object):
|
|||
for response in responses:
|
||||
if (u"signatures" not in response
|
||||
or perspective_name not in response[u"signatures"]):
|
||||
raise ValueError(
|
||||
raise KeyLookupError(
|
||||
"Key response not signed by perspective server"
|
||||
" %r" % (perspective_name,)
|
||||
)
|
||||
|
@ -441,13 +452,13 @@ class Keyring(object):
|
|||
list(response[u"signatures"][perspective_name]),
|
||||
list(perspective_keys)
|
||||
)
|
||||
raise ValueError(
|
||||
raise KeyLookupError(
|
||||
"Response not signed with a known key for perspective"
|
||||
" server %r" % (perspective_name,)
|
||||
)
|
||||
|
||||
processed_response = yield self.process_v2_response(
|
||||
perspective_name, response
|
||||
perspective_name, response, only_from_server=False
|
||||
)
|
||||
|
||||
for server_name, response_keys in processed_response.items():
|
||||
|
@ -484,10 +495,10 @@ class Keyring(object):
|
|||
|
||||
if (u"signatures" not in response
|
||||
or server_name not in response[u"signatures"]):
|
||||
raise ValueError("Key response not signed by remote server")
|
||||
raise KeyLookupError("Key response not signed by remote server")
|
||||
|
||||
if "tls_fingerprints" not in response:
|
||||
raise ValueError("Key response missing TLS fingerprints")
|
||||
raise KeyLookupError("Key response missing TLS fingerprints")
|
||||
|
||||
certificate_bytes = crypto.dump_certificate(
|
||||
crypto.FILETYPE_ASN1, tls_certificate
|
||||
|
@ -501,7 +512,7 @@ class Keyring(object):
|
|||
response_sha256_fingerprints.add(fingerprint[u"sha256"])
|
||||
|
||||
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
|
||||
raise ValueError("TLS certificate not allowed by fingerprints")
|
||||
raise KeyLookupError("TLS certificate not allowed by fingerprints")
|
||||
|
||||
response_keys = yield self.process_v2_response(
|
||||
from_server=server_name,
|
||||
|
@ -527,7 +538,7 @@ class Keyring(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def process_v2_response(self, from_server, response_json,
|
||||
requested_ids=[]):
|
||||
requested_ids=[], only_from_server=True):
|
||||
time_now_ms = self.clock.time_msec()
|
||||
response_keys = {}
|
||||
verify_keys = {}
|
||||
|
@ -551,9 +562,16 @@ class Keyring(object):
|
|||
|
||||
results = {}
|
||||
server_name = response_json["server_name"]
|
||||
if only_from_server:
|
||||
if server_name != from_server:
|
||||
raise KeyLookupError(
|
||||
"Expected a response for server %r not %r" % (
|
||||
from_server, server_name
|
||||
)
|
||||
)
|
||||
for key_id in response_json["signatures"].get(server_name, {}):
|
||||
if key_id not in response_json["verify_keys"]:
|
||||
raise ValueError(
|
||||
raise KeyLookupError(
|
||||
"Key response must include verification keys for all"
|
||||
" signatures"
|
||||
)
|
||||
|
@ -621,15 +639,15 @@ class Keyring(object):
|
|||
|
||||
if ("signatures" not in response
|
||||
or server_name not in response["signatures"]):
|
||||
raise ValueError("Key response not signed by remote server")
|
||||
raise KeyLookupError("Key response not signed by remote server")
|
||||
|
||||
if "tls_certificate" not in response:
|
||||
raise ValueError("Key response missing TLS certificate")
|
||||
raise KeyLookupError("Key response missing TLS certificate")
|
||||
|
||||
tls_certificate_b64 = response["tls_certificate"]
|
||||
|
||||
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
|
||||
raise ValueError("TLS certificate doesn't match")
|
||||
raise KeyLookupError("TLS certificate doesn't match")
|
||||
|
||||
# Cache the result in the datastore.
|
||||
|
||||
|
@ -645,7 +663,7 @@ class Keyring(object):
|
|||
|
||||
for key_id in response["signatures"][server_name]:
|
||||
if key_id not in response["verify_keys"]:
|
||||
raise ValueError(
|
||||
raise KeyLookupError(
|
||||
"Key response must include verification keys for all"
|
||||
" signatures"
|
||||
)
|
||||
|
|
|
@ -88,6 +88,8 @@ def prune_event(event):
|
|||
|
||||
if "age_ts" in event.unsigned:
|
||||
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
|
||||
if "replaces_state" in event.unsigned:
|
||||
allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
|
||||
|
||||
return type(event)(
|
||||
allowed_fields,
|
||||
|
|
|
@ -31,6 +31,9 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class FederationBase(object):
|
||||
def __init__(self, hs):
|
||||
pass
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
|
||||
include_none=False):
|
||||
|
|
|
@ -24,6 +24,7 @@ from synapse.api.errors import (
|
|||
CodeMessageException, HttpResponseException, SynapseError,
|
||||
)
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.events import FrozenEvent
|
||||
|
@ -50,7 +51,33 @@ sent_edus_counter = metrics.register_counter("sent_edus")
|
|||
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
|
||||
|
||||
|
||||
PDU_RETRY_TIME_MS = 1 * 60 * 1000
|
||||
|
||||
|
||||
class FederationClient(FederationBase):
|
||||
def __init__(self, hs):
|
||||
super(FederationClient, self).__init__(hs)
|
||||
|
||||
self.pdu_destination_tried = {}
|
||||
self._clock.looping_call(
|
||||
self._clear_tried_cache, 60 * 1000,
|
||||
)
|
||||
|
||||
def _clear_tried_cache(self):
|
||||
"""Clear pdu_destination_tried cache"""
|
||||
now = self._clock.time_msec()
|
||||
|
||||
old_dict = self.pdu_destination_tried
|
||||
self.pdu_destination_tried = {}
|
||||
|
||||
for event_id, destination_dict in old_dict.items():
|
||||
destination_dict = {
|
||||
dest: time
|
||||
for dest, time in destination_dict.items()
|
||||
if time + PDU_RETRY_TIME_MS > now
|
||||
}
|
||||
if destination_dict:
|
||||
self.pdu_destination_tried[event_id] = destination_dict
|
||||
|
||||
def start_get_pdu_cache(self):
|
||||
self._get_pdu_cache = ExpiringCache(
|
||||
|
@ -233,12 +260,19 @@ class FederationClient(FederationBase):
|
|||
# TODO: Rate limit the number of times we try and get the same event.
|
||||
|
||||
if self._get_pdu_cache:
|
||||
e = self._get_pdu_cache.get(event_id)
|
||||
if e:
|
||||
defer.returnValue(e)
|
||||
ev = self._get_pdu_cache.get(event_id)
|
||||
if ev:
|
||||
defer.returnValue(ev)
|
||||
|
||||
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
|
||||
|
||||
pdu = None
|
||||
for destination in destinations:
|
||||
now = self._clock.time_msec()
|
||||
last_attempt = pdu_attempts.get(destination, 0)
|
||||
if last_attempt + PDU_RETRY_TIME_MS > now:
|
||||
continue
|
||||
|
||||
try:
|
||||
limiter = yield get_retry_limiter(
|
||||
destination,
|
||||
|
@ -266,25 +300,19 @@ class FederationClient(FederationBase):
|
|||
|
||||
break
|
||||
|
||||
except SynapseError:
|
||||
logger.info(
|
||||
"Failed to get PDU %s from %s because %s",
|
||||
event_id, destination, e,
|
||||
)
|
||||
continue
|
||||
except CodeMessageException as e:
|
||||
if 400 <= e.code < 500:
|
||||
raise
|
||||
pdu_attempts[destination] = now
|
||||
|
||||
except SynapseError as e:
|
||||
logger.info(
|
||||
"Failed to get PDU %s from %s because %s",
|
||||
event_id, destination, e,
|
||||
)
|
||||
continue
|
||||
except NotRetryingDestination as e:
|
||||
logger.info(e.message)
|
||||
continue
|
||||
except Exception as e:
|
||||
pdu_attempts[destination] = now
|
||||
|
||||
logger.info(
|
||||
"Failed to get PDU %s from %s because %s",
|
||||
event_id, destination, e,
|
||||
|
@ -311,6 +339,42 @@ class FederationClient(FederationBase):
|
|||
Deferred: Results in a list of PDUs.
|
||||
"""
|
||||
|
||||
try:
|
||||
# First we try and ask for just the IDs, as thats far quicker if
|
||||
# we have most of the state and auth_chain already.
|
||||
# However, this may 404 if the other side has an old synapse.
|
||||
result = yield self.transport_layer.get_room_state_ids(
|
||||
destination, room_id, event_id=event_id,
|
||||
)
|
||||
|
||||
state_event_ids = result["pdu_ids"]
|
||||
auth_event_ids = result.get("auth_chain_ids", [])
|
||||
|
||||
fetched_events, failed_to_fetch = yield self.get_events(
|
||||
[destination], room_id, set(state_event_ids + auth_event_ids)
|
||||
)
|
||||
|
||||
if failed_to_fetch:
|
||||
logger.warn("Failed to get %r", failed_to_fetch)
|
||||
|
||||
event_map = {
|
||||
ev.event_id: ev for ev in fetched_events
|
||||
}
|
||||
|
||||
pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
|
||||
auth_chain = [
|
||||
event_map[e_id] for e_id in auth_event_ids if e_id in event_map
|
||||
]
|
||||
|
||||
auth_chain.sort(key=lambda e: e.depth)
|
||||
|
||||
defer.returnValue((pdus, auth_chain))
|
||||
except HttpResponseException as e:
|
||||
if e.code == 400 or e.code == 404:
|
||||
logger.info("Failed to use get_room_state_ids API, falling back")
|
||||
else:
|
||||
raise e
|
||||
|
||||
result = yield self.transport_layer.get_room_state(
|
||||
destination, room_id, event_id=event_id,
|
||||
)
|
||||
|
@ -324,18 +388,93 @@ class FederationClient(FederationBase):
|
|||
for p in result.get("auth_chain", [])
|
||||
]
|
||||
|
||||
seen_events = yield self.store.get_events([
|
||||
ev.event_id for ev in itertools.chain(pdus, auth_chain)
|
||||
])
|
||||
|
||||
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
|
||||
destination, pdus, outlier=True
|
||||
destination,
|
||||
[p for p in pdus if p.event_id not in seen_events],
|
||||
outlier=True
|
||||
)
|
||||
signed_pdus.extend(
|
||||
seen_events[p.event_id] for p in pdus if p.event_id in seen_events
|
||||
)
|
||||
|
||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
||||
destination, auth_chain, outlier=True
|
||||
destination,
|
||||
[p for p in auth_chain if p.event_id not in seen_events],
|
||||
outlier=True
|
||||
)
|
||||
signed_auth.extend(
|
||||
seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
|
||||
)
|
||||
|
||||
signed_auth.sort(key=lambda e: e.depth)
|
||||
|
||||
defer.returnValue((signed_pdus, signed_auth))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_events(self, destinations, room_id, event_ids, return_local=True):
|
||||
"""Fetch events from some remote destinations, checking if we already
|
||||
have them.
|
||||
|
||||
Args:
|
||||
destinations (list)
|
||||
room_id (str)
|
||||
event_ids (list)
|
||||
return_local (bool): Whether to include events we already have in
|
||||
the DB in the returned list of events
|
||||
|
||||
Returns:
|
||||
Deferred: A deferred resolving to a 2-tuple where the first is a list of
|
||||
events and the second is a list of event ids that we failed to fetch.
|
||||
"""
|
||||
if return_local:
|
||||
seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
|
||||
signed_events = seen_events.values()
|
||||
else:
|
||||
seen_events = yield self.store.have_events(event_ids)
|
||||
signed_events = []
|
||||
|
||||
failed_to_fetch = set()
|
||||
|
||||
missing_events = set(event_ids)
|
||||
for k in seen_events:
|
||||
missing_events.discard(k)
|
||||
|
||||
if not missing_events:
|
||||
defer.returnValue((signed_events, failed_to_fetch))
|
||||
|
||||
def random_server_list():
|
||||
srvs = list(destinations)
|
||||
random.shuffle(srvs)
|
||||
return srvs
|
||||
|
||||
batch_size = 20
|
||||
missing_events = list(missing_events)
|
||||
for i in xrange(0, len(missing_events), batch_size):
|
||||
batch = set(missing_events[i:i + batch_size])
|
||||
|
||||
deferreds = [
|
||||
self.get_pdu(
|
||||
destinations=random_server_list(),
|
||||
event_id=e_id,
|
||||
)
|
||||
for e_id in batch
|
||||
]
|
||||
|
||||
res = yield defer.DeferredList(deferreds, consumeErrors=True)
|
||||
for success, result in res:
|
||||
if success:
|
||||
signed_events.append(result)
|
||||
batch.discard(result.event_id)
|
||||
|
||||
# We removed all events we successfully fetched from `batch`
|
||||
failed_to_fetch.update(batch)
|
||||
|
||||
defer.returnValue((signed_events, failed_to_fetch))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_event_auth(self, destination, room_id, event_id):
|
||||
|
@ -411,14 +550,19 @@ class FederationClient(FederationBase):
|
|||
(destination, self.event_from_pdu_json(pdu_dict))
|
||||
)
|
||||
break
|
||||
except CodeMessageException:
|
||||
except CodeMessageException as e:
|
||||
if not 500 <= e.code < 600:
|
||||
raise
|
||||
else:
|
||||
logger.warn(
|
||||
"Failed to make_%s via %s: %s",
|
||||
membership, destination, e.message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
"Failed to make_%s via %s: %s",
|
||||
membership, destination, e.message
|
||||
)
|
||||
raise
|
||||
|
||||
raise RuntimeError("Failed to send to any server.")
|
||||
|
||||
|
@ -490,8 +634,14 @@ class FederationClient(FederationBase):
|
|||
"auth_chain": signed_auth,
|
||||
"origin": destination,
|
||||
})
|
||||
except CodeMessageException:
|
||||
except CodeMessageException as e:
|
||||
if not 500 <= e.code < 600:
|
||||
raise
|
||||
else:
|
||||
logger.exception(
|
||||
"Failed to send_join via %s: %s",
|
||||
destination, e.message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to send_join via %s: %s",
|
||||
|
@ -550,6 +700,25 @@ class FederationClient(FederationBase):
|
|||
|
||||
raise RuntimeError("Failed to send to any server.")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_public_rooms(self, destinations):
|
||||
results_by_server = {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_result(s):
|
||||
if s == self.server_name:
|
||||
defer.returnValue()
|
||||
|
||||
try:
|
||||
result = yield self.transport_layer.get_public_rooms(s)
|
||||
results_by_server[s] = result
|
||||
except:
|
||||
logger.exception("Error getting room list from server %r", s)
|
||||
|
||||
yield concurrently_execute(_get_result, destinations, 3)
|
||||
|
||||
defer.returnValue(results_by_server)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_auth(self, destination, room_id, event_id, local_auth):
|
||||
"""
|
||||
|
|
|
@ -19,11 +19,13 @@ from twisted.internet import defer
|
|||
from .federation_base import FederationBase
|
||||
from .units import Transaction, Edu
|
||||
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.events import FrozenEvent
|
||||
import synapse.metrics
|
||||
|
||||
from synapse.api.errors import FederationError, SynapseError
|
||||
from synapse.api.errors import AuthError, FederationError, SynapseError
|
||||
|
||||
from synapse.crypto.event_signing import compute_event_signature
|
||||
|
||||
|
@ -44,6 +46,18 @@ received_queries_counter = metrics.register_counter("received_queries", labels=[
|
|||
|
||||
|
||||
class FederationServer(FederationBase):
|
||||
def __init__(self, hs):
|
||||
super(FederationServer, self).__init__(hs)
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
self._room_pdu_linearizer = Linearizer()
|
||||
self._server_linearizer = Linearizer()
|
||||
|
||||
# We cache responses to state queries, as they take a while and often
|
||||
# come in waves.
|
||||
self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
|
||||
|
||||
def set_handler(self, handler):
|
||||
"""Sets the handler that the replication layer will use to communicate
|
||||
receipt of new PDUs from other home servers. The required methods are
|
||||
|
@ -83,11 +97,14 @@ class FederationServer(FederationBase):
|
|||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_backfill_request(self, origin, room_id, versions, limit):
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
pdus = yield self.handler.on_backfill_request(
|
||||
origin, room_id, versions, limit
|
||||
)
|
||||
|
||||
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
|
||||
res = self._transaction_from_pdus(pdus).get_dict()
|
||||
|
||||
defer.returnValue((200, res))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
|
@ -178,15 +195,59 @@ class FederationServer(FederationBase):
|
|||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_context_state_request(self, origin, room_id, event_id):
|
||||
if event_id:
|
||||
if not event_id:
|
||||
raise NotImplementedError("Specify an event")
|
||||
|
||||
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
result = self._state_resp_cache.get((room_id, event_id))
|
||||
if not result:
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
resp = yield self._state_resp_cache.set(
|
||||
(room_id, event_id),
|
||||
self._on_context_state_request_compute(room_id, event_id)
|
||||
)
|
||||
else:
|
||||
resp = yield result
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_state_ids_request(self, origin, room_id, event_id):
|
||||
if not event_id:
|
||||
raise NotImplementedError("Specify an event")
|
||||
|
||||
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
pdus = yield self.handler.get_state_for_pdu(
|
||||
origin, room_id, event_id,
|
||||
room_id, event_id,
|
||||
)
|
||||
auth_chain = yield self.store.get_auth_chain(
|
||||
[pdu.event_id for pdu in pdus]
|
||||
)
|
||||
|
||||
defer.returnValue((200, {
|
||||
"pdu_ids": [pdu.event_id for pdu in pdus],
|
||||
"auth_chain_ids": [pdu.event_id for pdu in auth_chain],
|
||||
}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _on_context_state_request_compute(self, room_id, event_id):
|
||||
pdus = yield self.handler.get_state_for_pdu(
|
||||
room_id, event_id,
|
||||
)
|
||||
auth_chain = yield self.store.get_auth_chain(
|
||||
[pdu.event_id for pdu in pdus]
|
||||
)
|
||||
|
||||
for event in auth_chain:
|
||||
# We sign these again because there was a bug where we
|
||||
# incorrectly signed things the first time round
|
||||
if self.hs.is_mine_id(event.event_id):
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
event,
|
||||
|
@ -194,13 +255,11 @@ class FederationServer(FederationBase):
|
|||
self.hs.config.signing_key[0]
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Specify an event")
|
||||
|
||||
defer.returnValue((200, {
|
||||
defer.returnValue({
|
||||
"pdus": [pdu.get_pdu_json() for pdu in pdus],
|
||||
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
|
||||
}))
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
|
@ -274,14 +333,16 @@ class FederationServer(FederationBase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_event_auth(self, origin, room_id, event_id):
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
time_now = self._clock.time_msec()
|
||||
auth_pdus = yield self.handler.on_event_auth(event_id)
|
||||
defer.returnValue((200, {
|
||||
res = {
|
||||
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
|
||||
}))
|
||||
}
|
||||
defer.returnValue((200, res))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_query_auth_request(self, origin, content, event_id):
|
||||
def on_query_auth_request(self, origin, content, room_id, event_id):
|
||||
"""
|
||||
Content is a dict with keys::
|
||||
auth_chain (list): A list of events that give the auth chain.
|
||||
|
@ -300,6 +361,7 @@ class FederationServer(FederationBase):
|
|||
Returns:
|
||||
Deferred: Results in `dict` with the same format as `content`
|
||||
"""
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
auth_chain = [
|
||||
self.event_from_pdu_json(e)
|
||||
for e in content["auth_chain"]
|
||||
|
@ -331,27 +393,9 @@ class FederationServer(FederationBase):
|
|||
(200, send_content)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_query_client_keys(self, origin, content):
|
||||
query = []
|
||||
for user_id, device_ids in content.get("device_keys", {}).items():
|
||||
if not device_ids:
|
||||
query.append((user_id, None))
|
||||
else:
|
||||
for device_id in device_ids:
|
||||
query.append((user_id, device_id))
|
||||
|
||||
results = yield self.store.get_e2e_device_keys(query)
|
||||
|
||||
json_result = {}
|
||||
for user_id, device_keys in results.items():
|
||||
for device_id, json_bytes in device_keys.items():
|
||||
json_result.setdefault(user_id, {})[device_id] = json.loads(
|
||||
json_bytes
|
||||
)
|
||||
|
||||
defer.returnValue({"device_keys": json_result})
|
||||
return self.on_query_request("client_keys", content)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
|
@ -377,10 +421,23 @@ class FederationServer(FederationBase):
|
|||
@log_function
|
||||
def on_get_missing_events(self, origin, room_id, earliest_events,
|
||||
latest_events, limit, min_depth):
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
logger.info(
|
||||
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
|
||||
" limit: %d, min_depth: %d",
|
||||
earliest_events, latest_events, limit, min_depth
|
||||
)
|
||||
missing_events = yield self.handler.on_get_missing_events(
|
||||
origin, room_id, earliest_events, latest_events, limit, min_depth
|
||||
)
|
||||
|
||||
if len(missing_events) < 5:
|
||||
logger.info(
|
||||
"Returning %d events: %r", len(missing_events), missing_events
|
||||
)
|
||||
else:
|
||||
logger.info("Returning %d events", len(missing_events))
|
||||
|
||||
time_now = self._clock.time_msec()
|
||||
|
||||
defer.returnValue({
|
||||
|
@ -481,6 +538,14 @@ class FederationServer(FederationBase):
|
|||
pdu.internal_metadata.outlier = True
|
||||
elif min_depth and pdu.depth > min_depth:
|
||||
if get_missing and prevs - seen:
|
||||
# If we're missing stuff, ensure we only fetch stuff one
|
||||
# at a time.
|
||||
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
|
||||
# We recalculate seen, since it may have changed.
|
||||
have_seen = yield self.store.have_events(prevs)
|
||||
seen = set(have_seen.keys())
|
||||
|
||||
if prevs - seen:
|
||||
latest = yield self.store.get_latest_event_ids_in_room(
|
||||
pdu.room_id
|
||||
)
|
||||
|
@ -490,6 +555,11 @@ class FederationServer(FederationBase):
|
|||
latest = set(latest)
|
||||
latest |= seen
|
||||
|
||||
logger.info(
|
||||
"Missing %d events for room %r: %r...",
|
||||
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
|
||||
)
|
||||
|
||||
missing_events = yield self.get_missing_events(
|
||||
origin,
|
||||
pdu.room_id,
|
||||
|
@ -517,6 +587,10 @@ class FederationServer(FederationBase):
|
|||
prevs = {e_id for e_id, _ in pdu.prev_events}
|
||||
seen = set(have_seen.keys())
|
||||
if prevs - seen:
|
||||
logger.info(
|
||||
"Still missing %d events for room %r: %r...",
|
||||
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
|
||||
)
|
||||
fetch_state = True
|
||||
|
||||
if fetch_state:
|
||||
|
@ -531,7 +605,7 @@ class FederationServer(FederationBase):
|
|||
origin, pdu.room_id, pdu.event_id,
|
||||
)
|
||||
except:
|
||||
logger.warn("Failed to get state for event: %s", pdu.event_id)
|
||||
logger.exception("Failed to get state for event: %s", pdu.event_id)
|
||||
|
||||
yield self.handler.on_receive_pdu(
|
||||
origin,
|
||||
|
|
|
@ -72,5 +72,7 @@ class ReplicationLayer(FederationClient, FederationServer):
|
|||
|
||||
self.hs = hs
|
||||
|
||||
super(ReplicationLayer, self).__init__(hs)
|
||||
|
||||
def __str__(self):
|
||||
return "<ReplicationLayer(%s)>" % self.server_name
|
||||
|
|
|
@ -21,11 +21,11 @@ from .units import Transaction
|
|||
|
||||
from synapse.api.errors import HttpResponseException
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from synapse.util.retryutils import (
|
||||
get_retry_limiter, NotRetryingDestination,
|
||||
)
|
||||
from synapse.util.metrics import measure_func
|
||||
import synapse.metrics
|
||||
|
||||
import logging
|
||||
|
@ -51,7 +51,7 @@ class TransactionQueue(object):
|
|||
|
||||
self.transport_layer = transport_layer
|
||||
|
||||
self._clock = hs.get_clock()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
# Is a mapping from destinations -> deferreds. Used to keep track
|
||||
# of which destinations have transactions in flight and when they are
|
||||
|
@ -82,7 +82,7 @@ class TransactionQueue(object):
|
|||
self.pending_failures_by_dest = {}
|
||||
|
||||
# HACK to get unique tx id
|
||||
self._next_txn_id = int(self._clock.time_msec())
|
||||
self._next_txn_id = int(self.clock.time_msec())
|
||||
|
||||
def can_send_to(self, destination):
|
||||
"""Can we send messages to the given server?
|
||||
|
@ -119,89 +119,46 @@ class TransactionQueue(object):
|
|||
if not destinations:
|
||||
return
|
||||
|
||||
deferreds = []
|
||||
|
||||
for destination in destinations:
|
||||
deferred = defer.Deferred()
|
||||
self.pending_pdus_by_dest.setdefault(destination, []).append(
|
||||
(pdu, deferred, order)
|
||||
(pdu, order)
|
||||
)
|
||||
|
||||
def chain(failure):
|
||||
if not deferred.called:
|
||||
deferred.errback(failure)
|
||||
preserve_context_over_fn(
|
||||
self._attempt_new_transaction, destination
|
||||
)
|
||||
|
||||
def log_failure(f):
|
||||
logger.warn("Failed to send pdu to %s: %s", destination, f.value)
|
||||
|
||||
deferred.addErrback(log_failure)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self._attempt_new_transaction(destination).addErrback(chain)
|
||||
|
||||
deferreds.append(deferred)
|
||||
|
||||
# NO inlineCallbacks
|
||||
def enqueue_edu(self, edu):
|
||||
destination = edu.destination
|
||||
|
||||
if not self.can_send_to(destination):
|
||||
return
|
||||
|
||||
deferred = defer.Deferred()
|
||||
self.pending_edus_by_dest.setdefault(destination, []).append(
|
||||
(edu, deferred)
|
||||
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
|
||||
|
||||
preserve_context_over_fn(
|
||||
self._attempt_new_transaction, destination
|
||||
)
|
||||
|
||||
def chain(failure):
|
||||
if not deferred.called:
|
||||
deferred.errback(failure)
|
||||
|
||||
def log_failure(f):
|
||||
logger.warn("Failed to send edu to %s: %s", destination, f.value)
|
||||
|
||||
deferred.addErrback(log_failure)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self._attempt_new_transaction(destination).addErrback(chain)
|
||||
|
||||
return deferred
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def enqueue_failure(self, failure, destination):
|
||||
if destination == self.server_name or destination == "localhost":
|
||||
return
|
||||
|
||||
deferred = defer.Deferred()
|
||||
|
||||
if not self.can_send_to(destination):
|
||||
return
|
||||
|
||||
self.pending_failures_by_dest.setdefault(
|
||||
destination, []
|
||||
).append(
|
||||
(failure, deferred)
|
||||
).append(failure)
|
||||
|
||||
preserve_context_over_fn(
|
||||
self._attempt_new_transaction, destination
|
||||
)
|
||||
|
||||
def chain(f):
|
||||
if not deferred.called:
|
||||
deferred.errback(f)
|
||||
|
||||
def log_failure(f):
|
||||
logger.warn("Failed to send failure to %s: %s", destination, f.value)
|
||||
|
||||
deferred.addErrback(log_failure)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self._attempt_new_transaction(destination).addErrback(chain)
|
||||
|
||||
yield deferred
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _attempt_new_transaction(self, destination):
|
||||
yield run_on_reactor()
|
||||
|
||||
while True:
|
||||
# list of (pending_pdu, deferred, order)
|
||||
if destination in self.pending_transactions:
|
||||
# XXX: pending_transactions can get stuck on by a never-ending
|
||||
|
@ -226,27 +183,31 @@ class TransactionQueue(object):
|
|||
logger.debug("TX [%s] Nothing to send", destination)
|
||||
return
|
||||
|
||||
yield self._send_new_transaction(
|
||||
destination, pending_pdus, pending_edus, pending_failures
|
||||
)
|
||||
|
||||
@measure_func("_send_new_transaction")
|
||||
@defer.inlineCallbacks
|
||||
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
|
||||
pending_failures):
|
||||
|
||||
# Sort based on the order field
|
||||
pending_pdus.sort(key=lambda t: t[1])
|
||||
pdus = [x[0] for x in pending_pdus]
|
||||
edus = pending_edus
|
||||
failures = [x.get_dict() for x in pending_failures]
|
||||
|
||||
try:
|
||||
self.pending_transactions[destination] = 1
|
||||
|
||||
logger.debug("TX [%s] _attempt_new_transaction", destination)
|
||||
|
||||
# Sort based on the order field
|
||||
pending_pdus.sort(key=lambda t: t[2])
|
||||
|
||||
pdus = [x[0] for x in pending_pdus]
|
||||
edus = [x[0] for x in pending_edus]
|
||||
failures = [x[0].get_dict() for x in pending_failures]
|
||||
deferreds = [
|
||||
x[1]
|
||||
for x in pending_pdus + pending_edus + pending_failures
|
||||
]
|
||||
|
||||
txn_id = str(self._next_txn_id)
|
||||
|
||||
limiter = yield get_retry_limiter(
|
||||
destination,
|
||||
self._clock,
|
||||
self.clock,
|
||||
self.store,
|
||||
)
|
||||
|
||||
|
@ -262,7 +223,7 @@ class TransactionQueue(object):
|
|||
logger.debug("TX [%s] Persisting transaction...", destination)
|
||||
|
||||
transaction = Transaction.create_new(
|
||||
origin_server_ts=int(self._clock.time_msec()),
|
||||
origin_server_ts=int(self.clock.time_msec()),
|
||||
transaction_id=txn_id,
|
||||
origin=self.server_name,
|
||||
destination=destination,
|
||||
|
@ -293,7 +254,7 @@ class TransactionQueue(object):
|
|||
# keys work
|
||||
def json_data_cb():
|
||||
data = transaction.get_dict()
|
||||
now = int(self._clock.time_msec())
|
||||
now = int(self.clock.time_msec())
|
||||
if "pdus" in data:
|
||||
for p in data["pdus"]:
|
||||
if "age_ts" in p:
|
||||
|
@ -333,22 +294,11 @@ class TransactionQueue(object):
|
|||
|
||||
logger.debug("TX [%s] Marked as delivered", destination)
|
||||
|
||||
logger.debug("TX [%s] Yielding to callbacks...", destination)
|
||||
|
||||
for deferred in deferreds:
|
||||
if code == 200:
|
||||
deferred.callback(None)
|
||||
else:
|
||||
deferred.errback(RuntimeError("Got status %d" % code))
|
||||
|
||||
# Ensures we don't continue until all callbacks on that
|
||||
# deferred have fired
|
||||
try:
|
||||
yield deferred
|
||||
except:
|
||||
pass
|
||||
|
||||
logger.debug("TX [%s] Yielded to callbacks", destination)
|
||||
if code != 200:
|
||||
for p in pdus:
|
||||
logger.info(
|
||||
"Failed to send event %s to %s", p.event_id, destination
|
||||
)
|
||||
except NotRetryingDestination:
|
||||
logger.info(
|
||||
"TX [%s] not ready for retry yet - "
|
||||
|
@ -363,6 +313,9 @@ class TransactionQueue(object):
|
|||
destination,
|
||||
e,
|
||||
)
|
||||
|
||||
for p in pdus:
|
||||
logger.info("Failed to send event %s to %s", p.event_id, destination)
|
||||
except Exception as e:
|
||||
# We capture this here as there as nothing actually listens
|
||||
# for this finishing functions deferred.
|
||||
|
@ -372,13 +325,9 @@ class TransactionQueue(object):
|
|||
e,
|
||||
)
|
||||
|
||||
for deferred in deferreds:
|
||||
if not deferred.called:
|
||||
deferred.errback(e)
|
||||
for p in pdus:
|
||||
logger.info("Failed to send event %s to %s", p.event_id, destination)
|
||||
|
||||
finally:
|
||||
# We want to be *very* sure we delete this after we stop processing
|
||||
self.pending_transactions.pop(destination, None)
|
||||
|
||||
# Check to see if there is anything else to send.
|
||||
self._attempt_new_transaction(destination)
|
||||
|
|
|
@ -54,6 +54,28 @@ class TransportLayerClient(object):
|
|||
destination, path=path, args={"event_id": event_id},
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_room_state_ids(self, destination, room_id, event_id):
|
||||
""" Requests all state for a given room from the given server at the
|
||||
given event. Returns the state's event_id's
|
||||
|
||||
Args:
|
||||
destination (str): The host name of the remote home server we want
|
||||
to get the state from.
|
||||
context (str): The name of the context we want the state of
|
||||
event_id (str): The event we want the context at.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a dict received from the remote homeserver.
|
||||
"""
|
||||
logger.debug("get_room_state_ids dest=%s, room=%s",
|
||||
destination, room_id)
|
||||
|
||||
path = PREFIX + "/state_ids/%s/" % room_id
|
||||
return self.client.get_json(
|
||||
destination, path=path, args={"event_id": event_id},
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_event(self, destination, event_id, timeout=None):
|
||||
""" Requests the pdu with give id and origin from the given server.
|
||||
|
@ -224,6 +246,18 @@ class TransportLayerClient(object):
|
|||
|
||||
defer.returnValue(response)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_public_rooms(self, remote_server):
|
||||
path = PREFIX + "/publicRooms"
|
||||
|
||||
response = yield self.client.get_json(
|
||||
destination=remote_server,
|
||||
path=path,
|
||||
)
|
||||
|
||||
defer.returnValue(response)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def exchange_third_party_invite(self, destination, room_id, event_dict):
|
||||
|
|
|
@ -18,13 +18,14 @@ from twisted.internet import defer
|
|||
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.http.servlet import parse_json_object_from_request, parse_string
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import simplejson as json
|
||||
import re
|
||||
import synapse
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -37,7 +38,7 @@ class TransportLayerServer(JsonResource):
|
|||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
super(TransportLayerServer, self).__init__(hs)
|
||||
super(TransportLayerServer, self).__init__(hs, canonical_json=False)
|
||||
|
||||
self.authenticator = Authenticator(hs)
|
||||
self.ratelimiter = FederationRateLimiter(
|
||||
|
@ -60,6 +61,16 @@ class TransportLayerServer(JsonResource):
|
|||
)
|
||||
|
||||
|
||||
class AuthenticationError(SynapseError):
|
||||
"""There was a problem authenticating the request"""
|
||||
pass
|
||||
|
||||
|
||||
class NoAuthenticationError(AuthenticationError):
|
||||
"""The request had no authentication information"""
|
||||
pass
|
||||
|
||||
|
||||
class Authenticator(object):
|
||||
def __init__(self, hs):
|
||||
self.keyring = hs.get_keyring()
|
||||
|
@ -67,7 +78,7 @@ class Authenticator(object):
|
|||
|
||||
# A method just so we can pass 'self' as the authenticator to the Servlets
|
||||
@defer.inlineCallbacks
|
||||
def authenticate_request(self, request):
|
||||
def authenticate_request(self, request, content):
|
||||
json_request = {
|
||||
"method": request.method,
|
||||
"uri": request.uri,
|
||||
|
@ -75,17 +86,10 @@ class Authenticator(object):
|
|||
"signatures": {},
|
||||
}
|
||||
|
||||
content = None
|
||||
origin = None
|
||||
|
||||
if request.method in ["PUT", "POST"]:
|
||||
# TODO: Handle other method types? other content types?
|
||||
try:
|
||||
content_bytes = request.content.read()
|
||||
content = json.loads(content_bytes)
|
||||
if content is not None:
|
||||
json_request["content"] = content
|
||||
except:
|
||||
raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
|
||||
|
||||
origin = None
|
||||
|
||||
def parse_auth_header(header_str):
|
||||
try:
|
||||
|
@ -103,14 +107,14 @@ class Authenticator(object):
|
|||
sig = strip_quotes(param_dict["sig"])
|
||||
return (origin, key, sig)
|
||||
except:
|
||||
raise SynapseError(
|
||||
raise AuthenticationError(
|
||||
400, "Malformed Authorization header", Codes.UNAUTHORIZED
|
||||
)
|
||||
|
||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||
|
||||
if not auth_headers:
|
||||
raise SynapseError(
|
||||
raise NoAuthenticationError(
|
||||
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
|
@ -121,7 +125,7 @@ class Authenticator(object):
|
|||
json_request["signatures"].setdefault(origin, {})[key] = sig
|
||||
|
||||
if not json_request["signatures"]:
|
||||
raise SynapseError(
|
||||
raise NoAuthenticationError(
|
||||
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
|
@ -130,38 +134,59 @@ class Authenticator(object):
|
|||
logger.info("Request from %s", origin)
|
||||
request.authenticated_entity = origin
|
||||
|
||||
defer.returnValue((origin, content))
|
||||
defer.returnValue(origin)
|
||||
|
||||
|
||||
class BaseFederationServlet(object):
|
||||
def __init__(self, handler, authenticator, ratelimiter, server_name):
|
||||
REQUIRE_AUTH = True
|
||||
|
||||
def __init__(self, handler, authenticator, ratelimiter, server_name,
|
||||
room_list_handler):
|
||||
self.handler = handler
|
||||
self.authenticator = authenticator
|
||||
self.ratelimiter = ratelimiter
|
||||
self.room_list_handler = room_list_handler
|
||||
|
||||
def _wrap(self, code):
|
||||
def _wrap(self, func):
|
||||
authenticator = self.authenticator
|
||||
ratelimiter = self.ratelimiter
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@functools.wraps(code)
|
||||
def new_code(request, *args, **kwargs):
|
||||
@functools.wraps(func)
|
||||
def new_func(request, *args, **kwargs):
|
||||
content = None
|
||||
if request.method in ["PUT", "POST"]:
|
||||
# TODO: Handle other method types? other content types?
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
try:
|
||||
(origin, content) = yield authenticator.authenticate_request(request)
|
||||
with ratelimiter.ratelimit(origin) as d:
|
||||
yield d
|
||||
response = yield code(
|
||||
origin, content, request.args, *args, **kwargs
|
||||
)
|
||||
origin = yield authenticator.authenticate_request(request, content)
|
||||
except NoAuthenticationError:
|
||||
origin = None
|
||||
if self.REQUIRE_AUTH:
|
||||
logger.exception("authenticate_request failed")
|
||||
raise
|
||||
except:
|
||||
logger.exception("authenticate_request failed")
|
||||
raise
|
||||
|
||||
if origin:
|
||||
with ratelimiter.ratelimit(origin) as d:
|
||||
yield d
|
||||
response = yield func(
|
||||
origin, content, request.args, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
response = yield func(
|
||||
origin, content, request.args, *args, **kwargs
|
||||
)
|
||||
|
||||
defer.returnValue(response)
|
||||
|
||||
# Extra logic that functools.wraps() doesn't finish
|
||||
new_code.__self__ = code.__self__
|
||||
new_func.__self__ = func.__self__
|
||||
|
||||
return new_code
|
||||
return new_func
|
||||
|
||||
def register(self, server):
|
||||
pattern = re.compile("^" + PREFIX + self.PATH + "$")
|
||||
|
@ -269,6 +294,17 @@ class FederationStateServlet(BaseFederationServlet):
|
|||
)
|
||||
|
||||
|
||||
class FederationStateIdsServlet(BaseFederationServlet):
|
||||
PATH = "/state_ids/(?P<room_id>[^/]*)/"
|
||||
|
||||
def on_GET(self, origin, content, query, room_id):
|
||||
return self.handler.on_state_ids_request(
|
||||
origin,
|
||||
room_id,
|
||||
query.get("event_id", [None])[0],
|
||||
)
|
||||
|
||||
|
||||
class FederationBackfillServlet(BaseFederationServlet):
|
||||
PATH = "/backfill/(?P<context>[^/]*)/"
|
||||
|
||||
|
@ -365,10 +401,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
|
|||
class FederationClientKeysQueryServlet(BaseFederationServlet):
|
||||
PATH = "/user/keys/query"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query):
|
||||
response = yield self.handler.on_query_client_keys(origin, content)
|
||||
defer.returnValue((200, response))
|
||||
return self.handler.on_query_client_keys(origin, content)
|
||||
|
||||
|
||||
class FederationClientKeysClaimServlet(BaseFederationServlet):
|
||||
|
@ -386,7 +420,7 @@ class FederationQueryAuthServlet(BaseFederationServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, context, event_id):
|
||||
new_content = yield self.handler.on_query_auth_request(
|
||||
origin, content, event_id
|
||||
origin, content, context, event_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
@ -418,9 +452,10 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
|
|||
class On3pidBindServlet(BaseFederationServlet):
|
||||
PATH = "/3pid/onbind"
|
||||
|
||||
REQUIRE_AUTH = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
content = parse_json_object_from_request(request)
|
||||
def on_POST(self, origin, content, query):
|
||||
if "invites" in content:
|
||||
last_exception = None
|
||||
for invite in content["invites"]:
|
||||
|
@ -442,11 +477,6 @@ class On3pidBindServlet(BaseFederationServlet):
|
|||
raise last_exception
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
# Avoid doing remote HS authorization checks which are done by default by
|
||||
# BaseFederationServlet.
|
||||
def _wrap(self, code):
|
||||
return code
|
||||
|
||||
|
||||
class OpenIdUserInfo(BaseFederationServlet):
|
||||
"""
|
||||
|
@ -467,9 +497,11 @@ class OpenIdUserInfo(BaseFederationServlet):
|
|||
|
||||
PATH = "/openid/userinfo"
|
||||
|
||||
REQUIRE_AUTH = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
token = parse_string(request, "access_token")
|
||||
def on_GET(self, origin, content, query):
|
||||
token = query.get("access_token", [None])[0]
|
||||
if token is None:
|
||||
defer.returnValue((401, {
|
||||
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
|
||||
|
@ -486,10 +518,58 @@ class OpenIdUserInfo(BaseFederationServlet):
|
|||
|
||||
defer.returnValue((200, {"sub": user_id}))
|
||||
|
||||
# Avoid doing remote HS authorization checks which are done by default by
|
||||
# BaseFederationServlet.
|
||||
def _wrap(self, code):
|
||||
return code
|
||||
|
||||
class PublicRoomList(BaseFederationServlet):
|
||||
"""
|
||||
Fetch the public room list for this server.
|
||||
|
||||
This API returns information in the same format as /publicRooms on the
|
||||
client API, but will only ever include local public rooms and hence is
|
||||
intended for consumption by other home servers.
|
||||
|
||||
GET /publicRooms HTTP/1.1
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"chunk": [
|
||||
{
|
||||
"aliases": [
|
||||
"#test:localhost"
|
||||
],
|
||||
"guest_can_join": false,
|
||||
"name": "test room",
|
||||
"num_joined_members": 3,
|
||||
"room_id": "!whkydVegtvatLfXmPN:localhost",
|
||||
"world_readable": false
|
||||
}
|
||||
],
|
||||
"end": "END",
|
||||
"start": "START"
|
||||
}
|
||||
"""
|
||||
|
||||
PATH = "/publicRooms"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query):
|
||||
data = yield self.room_list_handler.get_local_public_room_list()
|
||||
defer.returnValue((200, data))
|
||||
|
||||
|
||||
class FederationVersionServlet(BaseFederationServlet):
|
||||
PATH = "/version"
|
||||
|
||||
REQUIRE_AUTH = False
|
||||
|
||||
def on_GET(self, origin, content, query):
|
||||
return defer.succeed((200, {
|
||||
"server": {
|
||||
"name": "Synapse",
|
||||
"version": get_version_string(synapse)
|
||||
},
|
||||
}))
|
||||
|
||||
|
||||
SERVLET_CLASSES = (
|
||||
|
@ -497,6 +577,7 @@ SERVLET_CLASSES = (
|
|||
FederationPullServlet,
|
||||
FederationEventServlet,
|
||||
FederationStateServlet,
|
||||
FederationStateIdsServlet,
|
||||
FederationBackfillServlet,
|
||||
FederationQueryServlet,
|
||||
FederationMakeJoinServlet,
|
||||
|
@ -513,6 +594,8 @@ SERVLET_CLASSES = (
|
|||
FederationThirdPartyInviteExchangeServlet,
|
||||
On3pidBindServlet,
|
||||
OpenIdUserInfo,
|
||||
PublicRoomList,
|
||||
FederationVersionServlet,
|
||||
)
|
||||
|
||||
|
||||
|
@ -523,4 +606,5 @@ def register_servlets(hs, resource, authenticator, ratelimiter):
|
|||
authenticator=authenticator,
|
||||
ratelimiter=ratelimiter,
|
||||
server_name=hs.hostname,
|
||||
room_list_handler=hs.get_room_list_handler(),
|
||||
).register(resource)
|
||||
|
|
|
@ -13,11 +13,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.appservice.scheduler import AppServiceScheduler
|
||||
from synapse.appservice.api import ApplicationServiceApi
|
||||
from .register import RegistrationHandler
|
||||
from .room import (
|
||||
RoomCreationHandler, RoomListHandler, RoomContextHandler,
|
||||
RoomCreationHandler, RoomContextHandler,
|
||||
)
|
||||
from .room_member import RoomMemberHandler
|
||||
from .message import MessageHandler
|
||||
|
@ -26,8 +24,6 @@ from .federation import FederationHandler
|
|||
from .profile import ProfileHandler
|
||||
from .directory import DirectoryHandler
|
||||
from .admin import AdminHandler
|
||||
from .appservice import ApplicationServicesHandler
|
||||
from .auth import AuthHandler
|
||||
from .identity import IdentityHandler
|
||||
from .receipts import ReceiptsHandler
|
||||
from .search import SearchHandler
|
||||
|
@ -35,10 +31,21 @@ from .search import SearchHandler
|
|||
|
||||
class Handlers(object):
|
||||
|
||||
""" A collection of all the event handlers.
|
||||
""" Deprecated. A collection of handlers.
|
||||
|
||||
There's no need to lazily create these; we'll just make them all eagerly
|
||||
at construction time.
|
||||
At some point most of the classes whose name ended "Handler" were
|
||||
accessed through this class.
|
||||
|
||||
However this makes it painful to unit test the handlers and to run cut
|
||||
down versions of synapse that only use specific handlers because using a
|
||||
single handler required creating all of the handlers. So some of the
|
||||
handlers have been lifted out of the Handlers object and are now accessed
|
||||
directly through the homeserver object itself.
|
||||
|
||||
Any new handlers should follow the new pattern of being accessed through
|
||||
the homeserver object and should not be added to the Handlers object.
|
||||
|
||||
The remaining handlers should be moved out of the handlers object.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -50,19 +57,9 @@ class Handlers(object):
|
|||
self.event_handler = EventHandler(hs)
|
||||
self.federation_handler = FederationHandler(hs)
|
||||
self.profile_handler = ProfileHandler(hs)
|
||||
self.room_list_handler = RoomListHandler(hs)
|
||||
self.directory_handler = DirectoryHandler(hs)
|
||||
self.admin_handler = AdminHandler(hs)
|
||||
self.receipts_handler = ReceiptsHandler(hs)
|
||||
asapi = ApplicationServiceApi(hs)
|
||||
self.appservice_handler = ApplicationServicesHandler(
|
||||
hs, asapi, AppServiceScheduler(
|
||||
clock=hs.get_clock(),
|
||||
store=hs.get_datastore(),
|
||||
as_api=asapi
|
||||
)
|
||||
)
|
||||
self.auth_handler = AuthHandler(hs)
|
||||
self.identity_handler = IdentityHandler(hs)
|
||||
self.search_handler = SearchHandler(hs)
|
||||
self.room_context_handler = RoomContextHandler(hs)
|
||||
|
|
|
@ -13,14 +13,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import LimitExceededError
|
||||
import synapse.types
|
||||
from synapse.api.constants import Membership, EventTypes
|
||||
from synapse.types import UserID, Requester
|
||||
|
||||
|
||||
import logging
|
||||
from synapse.api.errors import LimitExceededError
|
||||
from synapse.types import UserID
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -31,11 +31,15 @@ class BaseHandler(object):
|
|||
Common base class for the event handlers.
|
||||
|
||||
Attributes:
|
||||
store (synapse.storage.events.StateStore):
|
||||
store (synapse.storage.DataStore):
|
||||
state_handler (synapse.state.StateHandler):
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.notifier = hs.get_notifier()
|
||||
|
@ -120,7 +124,8 @@ class BaseHandler(object):
|
|||
# and having homeservers have their own users leave keeps more
|
||||
# of that decision-making and control local to the guest-having
|
||||
# homeserver.
|
||||
requester = Requester(target_user, "", True)
|
||||
requester = synapse.types.create_requester(
|
||||
target_user, is_guest=True)
|
||||
handler = self.hs.get_handlers().room_member_handler
|
||||
yield handler.update_membership(
|
||||
requester,
|
||||
|
|
|
@ -17,7 +17,6 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.types import UserID
|
||||
|
||||
import logging
|
||||
|
||||
|
@ -35,16 +34,13 @@ def log_failure(failure):
|
|||
)
|
||||
|
||||
|
||||
# NB: Purposefully not inheriting BaseHandler since that contains way too much
|
||||
# setup code which this handler does not need or use. This makes testing a lot
|
||||
# easier.
|
||||
class ApplicationServicesHandler(object):
|
||||
|
||||
def __init__(self, hs, appservice_api, appservice_scheduler):
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.hs = hs
|
||||
self.appservice_api = appservice_api
|
||||
self.scheduler = appservice_scheduler
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.appservice_api = hs.get_application_service_api()
|
||||
self.scheduler = hs.get_application_service_scheduler()
|
||||
self.started_scheduler = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -169,8 +165,7 @@ class ApplicationServicesHandler(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _is_unknown_user(self, user_id):
|
||||
user = UserID.from_string(user_id)
|
||||
if not self.hs.is_mine(user):
|
||||
if not self.is_mine_id(user_id):
|
||||
# we don't know if they are unknown or not since it isn't one of our
|
||||
# users. We can't poke ASes.
|
||||
defer.returnValue(False)
|
||||
|
|
|
@ -18,8 +18,9 @@ from twisted.internet import defer
|
|||
from ._base import BaseHandler
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.types import UserID
|
||||
from synapse.api.errors import AuthError, LoginError, Codes
|
||||
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.config.ldap import LDAPMode
|
||||
|
||||
from twisted.web.client import PartialDownloadError
|
||||
|
||||
|
@ -28,6 +29,12 @@ import bcrypt
|
|||
import pymacaroons
|
||||
import simplejson
|
||||
|
||||
try:
|
||||
import ldap3
|
||||
except ImportError:
|
||||
ldap3 = None
|
||||
pass
|
||||
|
||||
import synapse.util.stringutils as stringutils
|
||||
|
||||
|
||||
|
@ -38,6 +45,10 @@ class AuthHandler(BaseHandler):
|
|||
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
super(AuthHandler, self).__init__(hs)
|
||||
self.checkers = {
|
||||
LoginType.PASSWORD: self._check_password_auth,
|
||||
|
@ -50,19 +61,23 @@ class AuthHandler(BaseHandler):
|
|||
self.INVALID_TOKEN_HTTP_STATUS = 401
|
||||
|
||||
self.ldap_enabled = hs.config.ldap_enabled
|
||||
self.ldap_server = hs.config.ldap_server
|
||||
self.ldap_port = hs.config.ldap_port
|
||||
self.ldap_tls = hs.config.ldap_tls
|
||||
self.ldap_search_base = hs.config.ldap_search_base
|
||||
self.ldap_search_property = hs.config.ldap_search_property
|
||||
self.ldap_email_property = hs.config.ldap_email_property
|
||||
self.ldap_full_name_property = hs.config.ldap_full_name_property
|
||||
|
||||
if self.ldap_enabled is True:
|
||||
import ldap
|
||||
logger.info("Import ldap version: %s", ldap.__version__)
|
||||
if self.ldap_enabled:
|
||||
if not ldap3:
|
||||
raise RuntimeError(
|
||||
'Missing ldap3 library. This is required for LDAP Authentication.'
|
||||
)
|
||||
self.ldap_mode = hs.config.ldap_mode
|
||||
self.ldap_uri = hs.config.ldap_uri
|
||||
self.ldap_start_tls = hs.config.ldap_start_tls
|
||||
self.ldap_base = hs.config.ldap_base
|
||||
self.ldap_filter = hs.config.ldap_filter
|
||||
self.ldap_attributes = hs.config.ldap_attributes
|
||||
if self.ldap_mode == LDAPMode.SEARCH:
|
||||
self.ldap_bind_dn = hs.config.ldap_bind_dn
|
||||
self.ldap_bind_password = hs.config.ldap_bind_password
|
||||
|
||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_auth(self, flows, clientdict, clientip):
|
||||
|
@ -220,7 +235,6 @@ class AuthHandler(BaseHandler):
|
|||
sess = self._get_session_info(session_id)
|
||||
return sess.setdefault('serverdict', {}).get(key, default)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_password_auth(self, authdict, _):
|
||||
if "user" not in authdict or "password" not in authdict:
|
||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||
|
@ -230,11 +244,7 @@ class AuthHandler(BaseHandler):
|
|||
if not user_id.startswith('@'):
|
||||
user_id = UserID.create(user_id, self.hs.hostname).to_string()
|
||||
|
||||
if not (yield self._check_password(user_id, password)):
|
||||
logger.warn("Failed password login for user %s", user_id)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
defer.returnValue(user_id)
|
||||
return self._check_password(user_id, password)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_recaptcha(self, authdict, clientip):
|
||||
|
@ -270,7 +280,16 @@ class AuthHandler(BaseHandler):
|
|||
data = pde.response
|
||||
resp_body = simplejson.loads(data)
|
||||
|
||||
if 'success' in resp_body and resp_body['success']:
|
||||
if 'success' in resp_body:
|
||||
# Note that we do NOT check the hostname here: we explicitly
|
||||
# intend the CAPTCHA to be presented by whatever client the
|
||||
# user is using, we just care that they have completed a CAPTCHA.
|
||||
logger.info(
|
||||
"%s reCAPTCHA from hostname %s",
|
||||
"Successful" if resp_body['success'] else "Failed",
|
||||
resp_body.get('hostname')
|
||||
)
|
||||
if resp_body['success']:
|
||||
defer.returnValue(True)
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
|
@ -338,67 +357,84 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
return self.sessions[session_id]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def login_with_password(self, user_id, password):
|
||||
def validate_password_login(self, user_id, password):
|
||||
"""
|
||||
Authenticates the user with their username and password.
|
||||
|
||||
Used only by the v1 login API.
|
||||
|
||||
Args:
|
||||
user_id (str): User ID
|
||||
user_id (str): complete @user:id
|
||||
password (str): Password
|
||||
Returns:
|
||||
A tuple of:
|
||||
The user's ID.
|
||||
The access token for the user's session.
|
||||
The refresh token for the user's session.
|
||||
defer.Deferred: (str) canonical user id
|
||||
Raises:
|
||||
StoreError if there was a problem storing the token.
|
||||
StoreError if there was a problem accessing the database
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
|
||||
if not (yield self._check_password(user_id, password)):
|
||||
logger.warn("Failed password login for user %s", user_id)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
logger.info("Logging in user %s", user_id)
|
||||
access_token = yield self.issue_access_token(user_id)
|
||||
refresh_token = yield self.issue_refresh_token(user_id)
|
||||
defer.returnValue((user_id, access_token, refresh_token))
|
||||
return self._check_password(user_id, password)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_login_tuple_for_user_id(self, user_id):
|
||||
def get_login_tuple_for_user_id(self, user_id, device_id=None,
|
||||
initial_display_name=None):
|
||||
"""
|
||||
Gets login tuple for the user with the given user ID.
|
||||
|
||||
Creates a new access/refresh token for the user.
|
||||
|
||||
The user is assumed to have been authenticated by some other
|
||||
machanism (e.g. CAS)
|
||||
machanism (e.g. CAS), and the user_id converted to the canonical case.
|
||||
|
||||
The device will be recorded in the table if it is not there already.
|
||||
|
||||
Args:
|
||||
user_id (str): User ID
|
||||
user_id (str): canonical User ID
|
||||
device_id (str|None): the device ID to associate with the tokens.
|
||||
None to leave the tokens unassociated with a device (deprecated:
|
||||
we should always have a device ID)
|
||||
initial_display_name (str): display name to associate with the
|
||||
device if it needs re-registering
|
||||
Returns:
|
||||
A tuple of:
|
||||
The user's ID.
|
||||
The access token for the user's session.
|
||||
The refresh token for the user's session.
|
||||
Raises:
|
||||
StoreError if there was a problem storing the token.
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
|
||||
logger.info("Logging in user %s on device %s", user_id, device_id)
|
||||
access_token = yield self.issue_access_token(user_id, device_id)
|
||||
refresh_token = yield self.issue_refresh_token(user_id, device_id)
|
||||
|
||||
logger.info("Logging in user %s", user_id)
|
||||
access_token = yield self.issue_access_token(user_id)
|
||||
refresh_token = yield self.issue_refresh_token(user_id)
|
||||
defer.returnValue((user_id, access_token, refresh_token))
|
||||
# the device *should* have been registered before we got here; however,
|
||||
# it's possible we raced against a DELETE operation. The thing we
|
||||
# really don't want is active access_tokens without a record of the
|
||||
# device, so we double-check it here.
|
||||
if device_id is not None:
|
||||
yield self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
defer.returnValue((access_token, refresh_token))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def does_user_exist(self, user_id):
|
||||
def check_user_exists(self, user_id):
|
||||
"""
|
||||
Checks to see if a user with the given id exists. Will check case
|
||||
insensitively, but return None if there are multiple inexact matches.
|
||||
|
||||
Args:
|
||||
(str) user_id: complete @user:id
|
||||
|
||||
Returns:
|
||||
defer.Deferred: (str) canonical_user_id, or None if zero or
|
||||
multiple matches
|
||||
"""
|
||||
try:
|
||||
yield self._find_user_id_and_pwd_hash(user_id)
|
||||
defer.returnValue(True)
|
||||
res = yield self._find_user_id_and_pwd_hash(user_id)
|
||||
defer.returnValue(res[0])
|
||||
except LoginError:
|
||||
defer.returnValue(False)
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _find_user_id_and_pwd_hash(self, user_id):
|
||||
|
@ -428,84 +464,232 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _check_password(self, user_id, password):
|
||||
"""
|
||||
"""Authenticate a user against the LDAP and local databases.
|
||||
|
||||
user_id is checked case insensitively against the local database, but
|
||||
will throw if there are multiple inexact matches.
|
||||
|
||||
Args:
|
||||
user_id (str): complete @user:id
|
||||
Returns:
|
||||
True if the user_id successfully authenticated
|
||||
(str) the canonical_user_id
|
||||
Raises:
|
||||
LoginError if the password was incorrect
|
||||
"""
|
||||
valid_ldap = yield self._check_ldap_password(user_id, password)
|
||||
if valid_ldap:
|
||||
defer.returnValue(True)
|
||||
defer.returnValue(user_id)
|
||||
|
||||
valid_local_password = yield self._check_local_password(user_id, password)
|
||||
if valid_local_password:
|
||||
defer.returnValue(True)
|
||||
|
||||
defer.returnValue(False)
|
||||
result = yield self._check_local_password(user_id, password)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_local_password(self, user_id, password):
|
||||
try:
|
||||
"""Authenticate a user against the local password database.
|
||||
|
||||
user_id is checked case insensitively, but will throw if there are
|
||||
multiple inexact matches.
|
||||
|
||||
Args:
|
||||
user_id (str): complete @user:id
|
||||
Returns:
|
||||
(str) the canonical_user_id
|
||||
Raises:
|
||||
LoginError if the password was incorrect
|
||||
"""
|
||||
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
||||
defer.returnValue(self.validate_hash(password, password_hash))
|
||||
except LoginError:
|
||||
defer.returnValue(False)
|
||||
result = self.validate_hash(password, password_hash)
|
||||
if not result:
|
||||
logger.warn("Failed password login for user %s", user_id)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
defer.returnValue(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_ldap_password(self, user_id, password):
|
||||
if not self.ldap_enabled:
|
||||
logger.debug("LDAP not configured")
|
||||
""" Attempt to authenticate a user against an LDAP Server
|
||||
and register an account if none exists.
|
||||
|
||||
Returns:
|
||||
True if authentication against LDAP was successful
|
||||
"""
|
||||
|
||||
if not ldap3 or not self.ldap_enabled:
|
||||
defer.returnValue(False)
|
||||
|
||||
import ldap
|
||||
|
||||
logger.info("Authenticating %s with LDAP" % user_id)
|
||||
try:
|
||||
ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port)
|
||||
logger.debug("Connecting LDAP server at %s" % ldap_url)
|
||||
l = ldap.initialize(ldap_url)
|
||||
if self.ldap_tls:
|
||||
logger.debug("Initiating TLS")
|
||||
self._connection.start_tls_s()
|
||||
|
||||
local_name = UserID.from_string(user_id).localpart
|
||||
|
||||
dn = "%s=%s, %s" % (
|
||||
self.ldap_search_property,
|
||||
local_name,
|
||||
self.ldap_search_base)
|
||||
logger.debug("DN for LDAP authentication: %s" % dn)
|
||||
|
||||
l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
|
||||
|
||||
if not (yield self.does_user_exist(user_id)):
|
||||
handler = self.hs.get_handlers().registration_handler
|
||||
user_id, access_token = (
|
||||
yield handler.register(localpart=local_name)
|
||||
if self.ldap_mode not in LDAPMode.LIST:
|
||||
raise RuntimeError(
|
||||
'Invalid ldap mode specified: {mode}'.format(
|
||||
mode=self.ldap_mode
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
server = ldap3.Server(self.ldap_uri)
|
||||
logger.debug(
|
||||
"Attempting ldap connection with %s",
|
||||
self.ldap_uri
|
||||
)
|
||||
|
||||
localpart = UserID.from_string(user_id).localpart
|
||||
if self.ldap_mode == LDAPMode.SIMPLE:
|
||||
# bind with the the local users ldap credentials
|
||||
bind_dn = "{prop}={value},{base}".format(
|
||||
prop=self.ldap_attributes['uid'],
|
||||
value=localpart,
|
||||
base=self.ldap_base
|
||||
)
|
||||
conn = ldap3.Connection(server, bind_dn, password)
|
||||
logger.debug(
|
||||
"Established ldap connection in simple mode: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
if self.ldap_start_tls:
|
||||
conn.start_tls()
|
||||
logger.debug(
|
||||
"Upgraded ldap connection in simple mode through StartTLS: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
conn.bind()
|
||||
|
||||
elif self.ldap_mode == LDAPMode.SEARCH:
|
||||
# connect with preconfigured credentials and search for local user
|
||||
conn = ldap3.Connection(
|
||||
server,
|
||||
self.ldap_bind_dn,
|
||||
self.ldap_bind_password
|
||||
)
|
||||
logger.debug(
|
||||
"Established ldap connection in search mode: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
if self.ldap_start_tls:
|
||||
conn.start_tls()
|
||||
logger.debug(
|
||||
"Upgraded ldap connection in search mode through StartTLS: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
conn.bind()
|
||||
|
||||
# find matching dn
|
||||
query = "({prop}={value})".format(
|
||||
prop=self.ldap_attributes['uid'],
|
||||
value=localpart
|
||||
)
|
||||
if self.ldap_filter:
|
||||
query = "(&{query}{filter})".format(
|
||||
query=query,
|
||||
filter=self.ldap_filter
|
||||
)
|
||||
logger.debug("ldap search filter: %s", query)
|
||||
result = conn.search(self.ldap_base, query)
|
||||
|
||||
if result and len(conn.response) == 1:
|
||||
# found exactly one result
|
||||
user_dn = conn.response[0]['dn']
|
||||
logger.debug('ldap search found dn: %s', user_dn)
|
||||
|
||||
# unbind and reconnect, rebind with found dn
|
||||
conn.unbind()
|
||||
conn = ldap3.Connection(
|
||||
server,
|
||||
user_dn,
|
||||
password,
|
||||
auto_bind=True
|
||||
)
|
||||
else:
|
||||
# found 0 or > 1 results, abort!
|
||||
logger.warn(
|
||||
"ldap search returned unexpected (%d!=1) amount of results",
|
||||
len(conn.response)
|
||||
)
|
||||
defer.returnValue(False)
|
||||
|
||||
logger.info(
|
||||
"User authenticated against ldap server: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
# check for existing account, if none exists, create one
|
||||
if not (yield self.check_user_exists(user_id)):
|
||||
# query user metadata for account creation
|
||||
query = "({prop}={value})".format(
|
||||
prop=self.ldap_attributes['uid'],
|
||||
value=localpart
|
||||
)
|
||||
|
||||
if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
|
||||
query = "(&{filter}{user_filter})".format(
|
||||
filter=query,
|
||||
user_filter=self.ldap_filter
|
||||
)
|
||||
logger.debug("ldap registration filter: %s", query)
|
||||
|
||||
result = conn.search(
|
||||
search_base=self.ldap_base,
|
||||
search_filter=query,
|
||||
attributes=[
|
||||
self.ldap_attributes['name'],
|
||||
self.ldap_attributes['mail']
|
||||
]
|
||||
)
|
||||
|
||||
if len(conn.response) == 1:
|
||||
attrs = conn.response[0]['attributes']
|
||||
mail = attrs[self.ldap_attributes['mail']][0]
|
||||
name = attrs[self.ldap_attributes['name']][0]
|
||||
|
||||
# create account
|
||||
registration_handler = self.hs.get_handlers().registration_handler
|
||||
user_id, access_token = (
|
||||
yield registration_handler.register(localpart=localpart)
|
||||
)
|
||||
|
||||
# TODO: bind email, set displayname with data from ldap directory
|
||||
|
||||
logger.info(
|
||||
"ldap registration successful: %d: %s (%s, %)",
|
||||
user_id,
|
||||
localpart,
|
||||
name,
|
||||
mail
|
||||
)
|
||||
else:
|
||||
logger.warn(
|
||||
"ldap registration failed: unexpected (%d!=1) amount of results",
|
||||
len(result)
|
||||
)
|
||||
defer.returnValue(False)
|
||||
|
||||
defer.returnValue(True)
|
||||
except ldap.LDAPError, e:
|
||||
logger.warn("LDAP error: %s", e)
|
||||
except ldap3.core.exceptions.LDAPException as e:
|
||||
logger.warn("Error during ldap authentication: %s", e)
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def issue_access_token(self, user_id):
|
||||
def issue_access_token(self, user_id, device_id=None):
|
||||
access_token = self.generate_access_token(user_id)
|
||||
yield self.store.add_access_token_to_user(user_id, access_token)
|
||||
yield self.store.add_access_token_to_user(user_id, access_token,
|
||||
device_id)
|
||||
defer.returnValue(access_token)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def issue_refresh_token(self, user_id):
|
||||
def issue_refresh_token(self, user_id, device_id=None):
|
||||
refresh_token = self.generate_refresh_token(user_id)
|
||||
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
|
||||
yield self.store.add_refresh_token_to_user(user_id, refresh_token,
|
||||
device_id)
|
||||
defer.returnValue(refresh_token)
|
||||
|
||||
def generate_access_token(self, user_id, extra_caveats=None):
|
||||
def generate_access_token(self, user_id, extra_caveats=None,
|
||||
duration_in_ms=(60 * 60 * 1000)):
|
||||
extra_caveats = extra_caveats or []
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
now = self.hs.get_clock().time_msec()
|
||||
expiry = now + (60 * 60 * 1000)
|
||||
expiry = now + duration_in_ms
|
||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
for caveat in extra_caveats:
|
||||
macaroon.add_first_party_caveat(caveat)
|
||||
|
@ -529,14 +713,20 @@ class AuthHandler(BaseHandler):
|
|||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
return macaroon.serialize()
|
||||
|
||||
def generate_delete_pusher_token(self, user_id):
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = delete_pusher")
|
||||
return macaroon.serialize()
|
||||
|
||||
def validate_short_term_login_token_and_get_user_id(self, login_token):
|
||||
auth_api = self.hs.get_auth()
|
||||
try:
|
||||
macaroon = pymacaroons.Macaroon.deserialize(login_token)
|
||||
auth_api = self.hs.get_auth()
|
||||
auth_api.validate_macaroon(macaroon, "login", True)
|
||||
return self.get_user_from_macaroon(macaroon)
|
||||
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
|
||||
raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
|
||||
user_id = auth_api.get_user_id_from_macaroon(macaroon)
|
||||
auth_api.validate_macaroon(macaroon, "login", True, user_id)
|
||||
return user_id
|
||||
except Exception:
|
||||
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
|
||||
|
||||
def _generate_base_macaroon(self, user_id):
|
||||
macaroon = pymacaroons.Macaroon(
|
||||
|
@ -547,23 +737,18 @@ class AuthHandler(BaseHandler):
|
|||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||
return macaroon
|
||||
|
||||
def get_user_from_macaroon(self, macaroon):
|
||||
user_prefix = "user_id = "
|
||||
for caveat in macaroon.caveats:
|
||||
if caveat.caveat_id.startswith(user_prefix):
|
||||
return caveat.caveat_id[len(user_prefix):]
|
||||
raise AuthError(
|
||||
self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_password(self, user_id, newpassword, requester=None):
|
||||
password_hash = self.hash(newpassword)
|
||||
|
||||
except_access_token_ids = [requester.access_token_id] if requester else []
|
||||
|
||||
try:
|
||||
yield self.store.user_set_password_hash(user_id, password_hash)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
|
||||
raise e
|
||||
yield self.store.user_delete_access_tokens(
|
||||
user_id, except_access_token_ids
|
||||
)
|
||||
|
@ -603,7 +788,8 @@ class AuthHandler(BaseHandler):
|
|||
Returns:
|
||||
Hashed password (str).
|
||||
"""
|
||||
return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds))
|
||||
return bcrypt.hashpw(password + self.hs.config.password_pepper,
|
||||
bcrypt.gensalt(self.bcrypt_rounds))
|
||||
|
||||
def validate_hash(self, password, stored_hash):
|
||||
"""Validates that self.hash(password) == stored_hash.
|
||||
|
@ -616,6 +802,7 @@ class AuthHandler(BaseHandler):
|
|||
Whether self.hash(password) == stored_hash (bool).
|
||||
"""
|
||||
if stored_hash:
|
||||
return bcrypt.hashpw(password, stored_hash) == stored_hash
|
||||
return bcrypt.hashpw(password + self.hs.config.password_pepper,
|
||||
stored_hash.encode('utf-8')) == stored_hash
|
||||
else:
|
||||
return False
|
||||
|
|
181
synapse/handlers/device.py
Normal file
181
synapse/handlers/device.py
Normal file
|
@ -0,0 +1,181 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.api import errors
|
||||
from synapse.util import stringutils
|
||||
from twisted.internet import defer
|
||||
from ._base import BaseHandler
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeviceHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
super(DeviceHandler, self).__init__(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_device_registered(self, user_id, device_id,
|
||||
initial_device_display_name=None):
|
||||
"""
|
||||
If the given device has not been registered, register it with the
|
||||
supplied display name.
|
||||
|
||||
If no device_id is supplied, we make one up.
|
||||
|
||||
Args:
|
||||
user_id (str): @user:id
|
||||
device_id (str | None): device id supplied by client
|
||||
initial_device_display_name (str | None): device display name from
|
||||
client
|
||||
Returns:
|
||||
str: device id (generated if none was supplied)
|
||||
"""
|
||||
if device_id is not None:
|
||||
yield self.store.store_device(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
initial_device_display_name=initial_device_display_name,
|
||||
ignore_if_known=True,
|
||||
)
|
||||
defer.returnValue(device_id)
|
||||
|
||||
# if the device id is not specified, we'll autogen one, but loop a few
|
||||
# times in case of a clash.
|
||||
attempts = 0
|
||||
while attempts < 5:
|
||||
try:
|
||||
device_id = stringutils.random_string_with_symbols(16)
|
||||
yield self.store.store_device(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
initial_device_display_name=initial_device_display_name,
|
||||
ignore_if_known=False,
|
||||
)
|
||||
defer.returnValue(device_id)
|
||||
except errors.StoreError:
|
||||
attempts += 1
|
||||
|
||||
raise errors.StoreError(500, "Couldn't generate a device ID.")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_devices_by_user(self, user_id):
|
||||
"""
|
||||
Retrieve the given user's devices
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
Returns:
|
||||
defer.Deferred: list[dict[str, X]]: info on each device
|
||||
"""
|
||||
|
||||
device_map = yield self.store.get_devices_by_user(user_id)
|
||||
|
||||
ips = yield self.store.get_last_client_ip_by_device(
|
||||
devices=((user_id, device_id) for device_id in device_map.keys())
|
||||
)
|
||||
|
||||
devices = device_map.values()
|
||||
for device in devices:
|
||||
_update_device_from_client_ips(device, ips)
|
||||
|
||||
defer.returnValue(devices)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_device(self, user_id, device_id):
|
||||
""" Retrieve the given device
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_id (str):
|
||||
|
||||
Returns:
|
||||
defer.Deferred: dict[str, X]: info on the device
|
||||
Raises:
|
||||
errors.NotFoundError: if the device was not found
|
||||
"""
|
||||
try:
|
||||
device = yield self.store.get_device(user_id, device_id)
|
||||
except errors.StoreError:
|
||||
raise errors.NotFoundError
|
||||
ips = yield self.store.get_last_client_ip_by_device(
|
||||
devices=((user_id, device_id),)
|
||||
)
|
||||
_update_device_from_client_ips(device, ips)
|
||||
defer.returnValue(device)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_device(self, user_id, device_id):
|
||||
""" Delete the given device
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_id (str):
|
||||
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
|
||||
try:
|
||||
yield self.store.delete_device(user_id, device_id)
|
||||
except errors.StoreError, e:
|
||||
if e.code == 404:
|
||||
# no match
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
yield self.store.user_delete_access_tokens(
|
||||
user_id, device_id=device_id,
|
||||
delete_refresh_tokens=True,
|
||||
)
|
||||
|
||||
yield self.store.delete_e2e_keys_by_device(
|
||||
user_id=user_id, device_id=device_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_device(self, user_id, device_id, content):
|
||||
""" Update the given device
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_id (str):
|
||||
content (dict): body of update request
|
||||
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
|
||||
try:
|
||||
yield self.store.update_device(
|
||||
user_id,
|
||||
device_id,
|
||||
new_display_name=content.get("display_name")
|
||||
)
|
||||
except errors.StoreError, e:
|
||||
if e.code == 404:
|
||||
raise errors.NotFoundError()
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def _update_device_from_client_ips(device, client_ips):
|
||||
ip = client_ips.get((device["user_id"], device["device_id"]), {})
|
||||
device.update({
|
||||
"last_seen_ts": ip.get("last_seen"),
|
||||
"last_seen_ip": ip.get("ip"),
|
||||
})
|
|
@ -33,6 +33,7 @@ class DirectoryHandler(BaseHandler):
|
|||
super(DirectoryHandler, self).__init__(hs)
|
||||
|
||||
self.state = hs.get_state_handler()
|
||||
self.appservice_handler = hs.get_application_service_handler()
|
||||
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.federation.register_query_handler(
|
||||
|
@ -281,7 +282,7 @@ class DirectoryHandler(BaseHandler):
|
|||
)
|
||||
if not result:
|
||||
# Query AS to see if it exists
|
||||
as_handler = self.hs.get_handlers().appservice_handler
|
||||
as_handler = self.appservice_handler
|
||||
result = yield as_handler.query_room_alias_exists(room_alias)
|
||||
defer.returnValue(result)
|
||||
|
||||
|
|
139
synapse/handlers/e2e_keys.py
Normal file
139
synapse/handlers/e2e_keys.py
Normal file
|
@ -0,0 +1,139 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api import errors
|
||||
import synapse.types
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class E2eKeysHandler(object):
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.server_name = hs.hostname
|
||||
|
||||
# doesn't really work as part of the generic query API, because the
|
||||
# query request requires an object POST, but we abuse the
|
||||
# "query handler" interface.
|
||||
self.federation.register_query_handler(
|
||||
"client_keys", self.on_federation_query_client_keys
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_devices(self, query_body):
|
||||
""" Handle a device key query from a client
|
||||
|
||||
{
|
||||
"device_keys": {
|
||||
"<user_id>": ["<device_id>"]
|
||||
}
|
||||
}
|
||||
->
|
||||
{
|
||||
"device_keys": {
|
||||
"<user_id>": {
|
||||
"<device_id>": {
|
||||
...
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
device_keys_query = query_body.get("device_keys", {})
|
||||
|
||||
# separate users by domain.
|
||||
# make a map from domain to user_id to device_ids
|
||||
queries_by_domain = collections.defaultdict(dict)
|
||||
for user_id, device_ids in device_keys_query.items():
|
||||
user = synapse.types.UserID.from_string(user_id)
|
||||
queries_by_domain[user.domain][user_id] = device_ids
|
||||
|
||||
# do the queries
|
||||
# TODO: do these in parallel
|
||||
results = {}
|
||||
for destination, destination_query in queries_by_domain.items():
|
||||
if destination == self.server_name:
|
||||
res = yield self.query_local_devices(destination_query)
|
||||
else:
|
||||
res = yield self.federation.query_client_keys(
|
||||
destination, {"device_keys": destination_query}
|
||||
)
|
||||
res = res["device_keys"]
|
||||
for user_id, keys in res.items():
|
||||
if user_id in destination_query:
|
||||
results[user_id] = keys
|
||||
|
||||
defer.returnValue((200, {"device_keys": results}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_local_devices(self, query):
|
||||
"""Get E2E device keys for local users
|
||||
|
||||
Args:
|
||||
query (dict[string, list[string]|None): map from user_id to a list
|
||||
of devices to query (None for all devices)
|
||||
|
||||
Returns:
|
||||
defer.Deferred: (resolves to dict[string, dict[string, dict]]):
|
||||
map from user_id -> device_id -> device details
|
||||
"""
|
||||
local_query = []
|
||||
|
||||
result_dict = {}
|
||||
for user_id, device_ids in query.items():
|
||||
if not self.is_mine_id(user_id):
|
||||
logger.warning("Request for keys for non-local user %s",
|
||||
user_id)
|
||||
raise errors.SynapseError(400, "Not a user here")
|
||||
|
||||
if not device_ids:
|
||||
local_query.append((user_id, None))
|
||||
else:
|
||||
for device_id in device_ids:
|
||||
local_query.append((user_id, device_id))
|
||||
|
||||
# make sure that each queried user appears in the result dict
|
||||
result_dict[user_id] = {}
|
||||
|
||||
results = yield self.store.get_e2e_device_keys(local_query)
|
||||
|
||||
# Build the result structure, un-jsonify the results, and add the
|
||||
# "unsigned" section
|
||||
for user_id, device_keys in results.items():
|
||||
for device_id, device_info in device_keys.items():
|
||||
r = json.loads(device_info["key_json"])
|
||||
r["unsigned"] = {}
|
||||
display_name = device_info["device_display_name"]
|
||||
if display_name is not None:
|
||||
r["unsigned"]["device_display_name"] = display_name
|
||||
result_dict[user_id][device_id] = r
|
||||
|
||||
defer.returnValue(result_dict)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_federation_query_client_keys(self, query_body):
|
||||
""" Handle a device key query from a federated server
|
||||
"""
|
||||
device_keys_query = query_body.get("device_keys", {})
|
||||
res = yield self.query_local_devices(device_keys_query)
|
||||
defer.returnValue({"device_keys": res})
|
|
@ -66,10 +66,6 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
self.hs = hs
|
||||
|
||||
self.distributor.observe("user_joined_room", self.user_joined_room)
|
||||
|
||||
self.waiting_for_join_list = {}
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.replication_layer = hs.get_replication_layer()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
|
@ -128,7 +124,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
try:
|
||||
event_stream_id, max_stream_id = yield self._persist_auth_tree(
|
||||
auth_chain, state, event
|
||||
origin, auth_chain, state, event
|
||||
)
|
||||
except AuthError as e:
|
||||
raise FederationError(
|
||||
|
@ -253,7 +249,7 @@ class FederationHandler(BaseHandler):
|
|||
if ev.type != EventTypes.Member:
|
||||
continue
|
||||
try:
|
||||
domain = UserID.from_string(ev.state_key).domain
|
||||
domain = get_domain_from_id(ev.state_key)
|
||||
except:
|
||||
continue
|
||||
|
||||
|
@ -339,16 +335,35 @@ class FederationHandler(BaseHandler):
|
|||
state_events.update({s.event_id: s for s in state})
|
||||
events_to_state[e_id] = state
|
||||
|
||||
seen_events = yield self.store.have_events(
|
||||
set(auth_events.keys()) | set(state_events.keys())
|
||||
)
|
||||
|
||||
all_events = events + state_events.values() + auth_events.values()
|
||||
required_auth = set(
|
||||
a_id for event in all_events for a_id, _ in event.auth_events
|
||||
a_id
|
||||
for event in events + state_events.values() + auth_events.values()
|
||||
for a_id, _ in event.auth_events
|
||||
)
|
||||
auth_events.update({
|
||||
e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
|
||||
})
|
||||
missing_auth = required_auth - set(auth_events)
|
||||
failed_to_fetch = set()
|
||||
|
||||
# Try and fetch any missing auth events from both DB and remote servers.
|
||||
# We repeatedly do this until we stop finding new auth events.
|
||||
while missing_auth - failed_to_fetch:
|
||||
logger.info("Missing auth for backfill: %r", missing_auth)
|
||||
ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
|
||||
auth_events.update(ret_events)
|
||||
|
||||
required_auth.update(
|
||||
a_id for event in ret_events.values() for a_id, _ in event.auth_events
|
||||
)
|
||||
missing_auth = required_auth - set(auth_events)
|
||||
|
||||
if missing_auth - failed_to_fetch:
|
||||
logger.info(
|
||||
"Fetching missing auth for backfill: %r",
|
||||
missing_auth - failed_to_fetch
|
||||
)
|
||||
|
||||
missing_auth = required_auth - set(auth_events)
|
||||
results = yield defer.gatherResults(
|
||||
[
|
||||
self.replication_layer.get_pdu(
|
||||
|
@ -357,11 +372,21 @@ class FederationHandler(BaseHandler):
|
|||
outlier=True,
|
||||
timeout=10000,
|
||||
)
|
||||
for event_id in missing_auth
|
||||
for event_id in missing_auth - failed_to_fetch
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
auth_events.update({a.event_id: a for a in results})
|
||||
required_auth.update(
|
||||
a_id for event in results for a_id, _ in event.auth_events
|
||||
)
|
||||
missing_auth = required_auth - set(auth_events)
|
||||
|
||||
failed_to_fetch = missing_auth - set(auth_events)
|
||||
|
||||
seen_events = yield self.store.have_events(
|
||||
set(auth_events.keys()) | set(state_events.keys())
|
||||
)
|
||||
|
||||
ev_infos = []
|
||||
for a in auth_events.values():
|
||||
|
@ -374,6 +399,7 @@ class FederationHandler(BaseHandler):
|
|||
(auth_events[a_id].type, auth_events[a_id].state_key):
|
||||
auth_events[a_id]
|
||||
for a_id, _ in a.auth_events
|
||||
if a_id in auth_events
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -385,6 +411,7 @@ class FederationHandler(BaseHandler):
|
|||
(auth_events[a_id].type, auth_events[a_id].state_key):
|
||||
auth_events[a_id]
|
||||
for a_id, _ in event_map[e_id].auth_events
|
||||
if a_id in auth_events
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -403,7 +430,7 @@ class FederationHandler(BaseHandler):
|
|||
# previous to work out the state.
|
||||
# TODO: We can probably do something more clever here.
|
||||
yield self._handle_new_event(
|
||||
dest, event
|
||||
dest, event, backfilled=True,
|
||||
)
|
||||
|
||||
defer.returnValue(events)
|
||||
|
@ -639,7 +666,7 @@ class FederationHandler(BaseHandler):
|
|||
pass
|
||||
|
||||
event_stream_id, max_stream_id = yield self._persist_auth_tree(
|
||||
auth_chain, state, event
|
||||
origin, auth_chain, state, event
|
||||
)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
|
@ -690,7 +717,9 @@ class FederationHandler(BaseHandler):
|
|||
logger.warn("Failed to create join %r because %s", event, e)
|
||||
raise e
|
||||
|
||||
self.auth.check(event, auth_events=context.current_state)
|
||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||
# when we get the event back in `on_send_join_request`
|
||||
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
|
@ -920,7 +949,9 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
try:
|
||||
self.auth.check(event, auth_events=context.current_state)
|
||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||
# when we get the event back in `on_send_leave_request`
|
||||
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
|
||||
except AuthError as e:
|
||||
logger.warn("Failed to create new leave %r because %s", event, e)
|
||||
raise e
|
||||
|
@ -989,14 +1020,9 @@ class FederationHandler(BaseHandler):
|
|||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
|
||||
def get_state_for_pdu(self, room_id, event_id):
|
||||
yield run_on_reactor()
|
||||
|
||||
if do_auth:
|
||||
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
state_groups = yield self.store.get_state_groups(
|
||||
room_id, [event_id]
|
||||
)
|
||||
|
@ -1020,6 +1046,9 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
res = results.values()
|
||||
for event in res:
|
||||
# We sign these again because there was a bug where we
|
||||
# incorrectly signed things the first time round
|
||||
if self.hs.is_mine_id(event.event_id):
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
event,
|
||||
|
@ -1064,6 +1093,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
if event:
|
||||
if self.hs.is_mine_id(event.event_id):
|
||||
# FIXME: This is a temporary work around where we occasionally
|
||||
# return events slightly differently than when they were
|
||||
# originally signed
|
||||
|
@ -1083,6 +1113,12 @@ class FederationHandler(BaseHandler):
|
|||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
events = yield self._filter_events_for_server(
|
||||
origin, event.room_id, [event]
|
||||
)
|
||||
|
||||
event = events[0]
|
||||
|
||||
defer.returnValue(event)
|
||||
else:
|
||||
defer.returnValue(None)
|
||||
|
@ -1091,15 +1127,6 @@ class FederationHandler(BaseHandler):
|
|||
def get_min_depth_for_context(self, context):
|
||||
return self.store.get_min_depth(context)
|
||||
|
||||
@log_function
|
||||
def user_joined_room(self, user, room_id):
|
||||
waiters = self.waiting_for_join_list.get(
|
||||
(user.to_string(), room_id),
|
||||
[]
|
||||
)
|
||||
while waiters:
|
||||
waiters.pop().callback(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _handle_new_event(self, origin, event, state=None, auth_events=None,
|
||||
|
@ -1122,6 +1149,7 @@ class FederationHandler(BaseHandler):
|
|||
backfilled=backfilled,
|
||||
)
|
||||
|
||||
if not backfilled:
|
||||
# this intentionally does not yield: we don't care about the result
|
||||
# and don't need to wait for it.
|
||||
preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
|
||||
|
@ -1158,11 +1186,19 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _persist_auth_tree(self, auth_events, state, event):
|
||||
def _persist_auth_tree(self, origin, auth_events, state, event):
|
||||
"""Checks the auth chain is valid (and passes auth checks) for the
|
||||
state and event. Then persists the auth chain and state atomically.
|
||||
Persists the event seperately.
|
||||
|
||||
Will attempt to fetch missing auth events.
|
||||
|
||||
Args:
|
||||
origin (str): Where the events came from
|
||||
auth_events (list)
|
||||
state (list)
|
||||
event (Event)
|
||||
|
||||
Returns:
|
||||
2-tuple of (event_stream_id, max_stream_id) from the persist_event
|
||||
call for `event`
|
||||
|
@ -1175,7 +1211,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
event_map = {
|
||||
e.event_id: e
|
||||
for e in auth_events
|
||||
for e in itertools.chain(auth_events, state, [event])
|
||||
}
|
||||
|
||||
create_event = None
|
||||
|
@ -1184,10 +1220,29 @@ class FederationHandler(BaseHandler):
|
|||
create_event = e
|
||||
break
|
||||
|
||||
missing_auth_events = set()
|
||||
for e in itertools.chain(auth_events, state, [event]):
|
||||
for e_id, _ in e.auth_events:
|
||||
if e_id not in event_map:
|
||||
missing_auth_events.add(e_id)
|
||||
|
||||
for e_id in missing_auth_events:
|
||||
m_ev = yield self.replication_layer.get_pdu(
|
||||
[origin],
|
||||
e_id,
|
||||
outlier=True,
|
||||
timeout=10000,
|
||||
)
|
||||
if m_ev and m_ev.event_id == e_id:
|
||||
event_map[e_id] = m_ev
|
||||
else:
|
||||
logger.info("Failed to find auth event %r", e_id)
|
||||
|
||||
for e in itertools.chain(auth_events, state, [event]):
|
||||
auth_for_e = {
|
||||
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
|
||||
for e_id, _ in e.auth_events
|
||||
if e_id in event_map
|
||||
}
|
||||
if create_event:
|
||||
auth_for_e[(EventTypes.Create, "")] = create_event
|
||||
|
@ -1421,7 +1476,7 @@ class FederationHandler(BaseHandler):
|
|||
local_view = dict(auth_events)
|
||||
remote_view = dict(auth_events)
|
||||
remote_view.update({
|
||||
(d.type, d.state_key): d for d in different_events
|
||||
(d.type, d.state_key): d for d in different_events if d
|
||||
})
|
||||
|
||||
new_state, prev_state = self.state_handler.resolve_events(
|
||||
|
|
|
@ -21,7 +21,7 @@ from synapse.api.errors import (
|
|||
)
|
||||
from ._base import BaseHandler
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
@ -41,6 +41,20 @@ class IdentityHandler(BaseHandler):
|
|||
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
|
||||
)
|
||||
|
||||
def _should_trust_id_server(self, id_server):
|
||||
if id_server not in self.trusted_id_servers:
|
||||
if self.trust_any_id_server_just_for_testing_do_not_use:
|
||||
logger.warn(
|
||||
"Trusting untrustworthy ID server %r even though it isn't"
|
||||
" in the trusted id list for testing because"
|
||||
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
|
||||
" is set in the config",
|
||||
id_server,
|
||||
)
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def threepid_from_creds(self, creds):
|
||||
yield run_on_reactor()
|
||||
|
@ -59,18 +73,11 @@ class IdentityHandler(BaseHandler):
|
|||
else:
|
||||
raise SynapseError(400, "No client_secret in creds")
|
||||
|
||||
if id_server not in self.trusted_id_servers:
|
||||
if self.trust_any_id_server_just_for_testing_do_not_use:
|
||||
if not self._should_trust_id_server(id_server):
|
||||
logger.warn(
|
||||
"Trusting untrustworthy ID server %r even though it isn't"
|
||||
" in the trusted id list for testing because"
|
||||
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
|
||||
" is set in the config",
|
||||
id_server,
|
||||
'%s is not a trusted ID server: rejecting 3pid ' +
|
||||
'credentials', id_server
|
||||
)
|
||||
else:
|
||||
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
|
||||
'credentials', id_server)
|
||||
defer.returnValue(None)
|
||||
|
||||
data = {}
|
||||
|
@ -129,6 +136,12 @@ class IdentityHandler(BaseHandler):
|
|||
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
|
||||
yield run_on_reactor()
|
||||
|
||||
if not self._should_trust_id_server(id_server):
|
||||
raise SynapseError(
|
||||
400, "Untrusted ID server '%s'" % id_server,
|
||||
Codes.SERVER_NOT_TRUSTED
|
||||
)
|
||||
|
||||
params = {
|
||||
'email': email,
|
||||
'client_secret': client_secret,
|
||||
|
|
|
@ -26,9 +26,9 @@ from synapse.types import (
|
|||
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
|
||||
)
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
|
||||
from synapse.util.caches.snapshot_cache import SnapshotCache
|
||||
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
@ -50,9 +50,23 @@ class MessageHandler(BaseHandler):
|
|||
self.validator = EventValidator()
|
||||
self.snapshot_cache = SnapshotCache()
|
||||
|
||||
self.pagination_lock = ReadWriteLock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def purge_history(self, room_id, event_id):
|
||||
event = yield self.store.get_event(event_id)
|
||||
|
||||
if event.room_id != room_id:
|
||||
raise SynapseError(400, "Event is for wrong room.")
|
||||
|
||||
depth = event.depth
|
||||
|
||||
with (yield self.pagination_lock.write(room_id)):
|
||||
yield self.store.delete_old_state(room_id, depth)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_messages(self, requester, room_id=None, pagin_config=None,
|
||||
as_client_event=True):
|
||||
as_client_event=True, event_filter=None):
|
||||
"""Get messages in a room.
|
||||
|
||||
Args:
|
||||
|
@ -61,11 +75,11 @@ class MessageHandler(BaseHandler):
|
|||
pagin_config (synapse.api.streams.PaginationConfig): The pagination
|
||||
config rules to apply, if any.
|
||||
as_client_event (bool): True to get events in client-server format.
|
||||
event_filter (Filter): Filter to apply to results or None
|
||||
Returns:
|
||||
dict: Pagination API results
|
||||
"""
|
||||
user_id = requester.user.to_string()
|
||||
data_source = self.hs.get_event_sources().sources["room"]
|
||||
|
||||
if pagin_config.from_token:
|
||||
room_token = pagin_config.from_token.room_key
|
||||
|
@ -85,6 +99,7 @@ class MessageHandler(BaseHandler):
|
|||
|
||||
source_config = pagin_config.get_source_config("room")
|
||||
|
||||
with (yield self.pagination_lock.read(room_id)):
|
||||
membership, member_event_id = yield self._check_in_room_or_world_readable(
|
||||
room_id, user_id
|
||||
)
|
||||
|
@ -95,7 +110,7 @@ class MessageHandler(BaseHandler):
|
|||
if room_token.topological:
|
||||
max_topo = room_token.topological
|
||||
else:
|
||||
max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
|
||||
max_topo = yield self.store.get_max_topological_token(
|
||||
room_id, room_token.stream
|
||||
)
|
||||
|
||||
|
@ -114,8 +129,13 @@ class MessageHandler(BaseHandler):
|
|||
room_id, max_topo
|
||||
)
|
||||
|
||||
events, next_key = yield data_source.get_pagination_rows(
|
||||
requester.user, source_config, room_id
|
||||
events, next_key = yield self.store.paginate_room_events(
|
||||
room_id=room_id,
|
||||
from_key=source_config.from_key,
|
||||
to_key=source_config.to_key,
|
||||
direction=source_config.direction,
|
||||
limit=source_config.limit,
|
||||
event_filter=event_filter,
|
||||
)
|
||||
|
||||
next_token = pagin_config.from_token.copy_and_replace(
|
||||
|
@ -129,6 +149,9 @@ class MessageHandler(BaseHandler):
|
|||
"end": next_token.to_string(),
|
||||
})
|
||||
|
||||
if event_filter:
|
||||
events = event_filter.filter(events)
|
||||
|
||||
events = yield filter_events_for_client(
|
||||
self.store,
|
||||
user_id,
|
||||
|
@ -908,13 +931,16 @@ class MessageHandler(BaseHandler):
|
|||
"Failed to get destination from event %s", s.event_id
|
||||
)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
# Don't block waiting on waking up all the listeners.
|
||||
@defer.inlineCallbacks
|
||||
def _notify():
|
||||
yield run_on_reactor()
|
||||
self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id,
|
||||
extra_users=extra_users
|
||||
)
|
||||
|
||||
preserve_fn(_notify)()
|
||||
|
||||
# If invite, remove room_state from unsigned before sending.
|
||||
event.unsigned.pop("invite_room_state", None)
|
||||
|
||||
|
|
|
@ -50,6 +50,8 @@ timers_fired_counter = metrics.register_counter("timers_fired")
|
|||
federation_presence_counter = metrics.register_counter("federation_presence")
|
||||
bump_active_time_counter = metrics.register_counter("bump_active_time")
|
||||
|
||||
get_updates_counter = metrics.register_counter("get_updates", labels=["type"])
|
||||
|
||||
|
||||
# If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them
|
||||
# "currently_active"
|
||||
|
@ -68,6 +70,10 @@ FEDERATION_TIMEOUT = 30 * 60 * 1000
|
|||
# How often to resend presence to remote servers
|
||||
FEDERATION_PING_INTERVAL = 25 * 60 * 1000
|
||||
|
||||
# How long we will wait before assuming that the syncs from an external process
|
||||
# are dead.
|
||||
EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000
|
||||
|
||||
assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
|
||||
|
||||
|
||||
|
@ -158,15 +164,26 @@ class PresenceHandler(object):
|
|||
self.serial_to_user = {}
|
||||
self._next_serial = 1
|
||||
|
||||
# Keeps track of the number of *ongoing* syncs. While this is non zero
|
||||
# a user will never go offline.
|
||||
# Keeps track of the number of *ongoing* syncs on this process. While
|
||||
# this is non zero a user will never go offline.
|
||||
self.user_to_num_current_syncs = {}
|
||||
|
||||
# Keeps track of the number of *ongoing* syncs on other processes.
|
||||
# While any sync is ongoing on another process the user will never
|
||||
# go offline.
|
||||
# Each process has a unique identifier and an update frequency. If
|
||||
# no update is received from that process within the update period then
|
||||
# we assume that all the sync requests on that process have stopped.
|
||||
# Stored as a dict from process_id to set of user_id, and a dict of
|
||||
# process_id to millisecond timestamp last updated.
|
||||
self.external_process_to_current_syncs = {}
|
||||
self.external_process_last_updated_ms = {}
|
||||
|
||||
# Start a LoopingCall in 30s that fires every 5s.
|
||||
# The initial delay is to allow disconnected clients a chance to
|
||||
# reconnect before we treat them as offline.
|
||||
self.clock.call_later(
|
||||
30 * 1000,
|
||||
30,
|
||||
self.clock.looping_call,
|
||||
self._handle_timeouts,
|
||||
5000,
|
||||
|
@ -266,19 +283,34 @@ class PresenceHandler(object):
|
|||
"""Checks the presence of users that have timed out and updates as
|
||||
appropriate.
|
||||
"""
|
||||
logger.info("Handling presence timeouts")
|
||||
now = self.clock.time_msec()
|
||||
|
||||
try:
|
||||
with Measure(self.clock, "presence_handle_timeouts"):
|
||||
# Fetch the list of users that *may* have timed out. Things may have
|
||||
# changed since the timeout was set, so we won't necessarily have to
|
||||
# take any action.
|
||||
users_to_check = self.wheel_timer.fetch(now)
|
||||
users_to_check = set(self.wheel_timer.fetch(now))
|
||||
|
||||
# Check whether the lists of syncing processes from an external
|
||||
# process have expired.
|
||||
expired_process_ids = [
|
||||
process_id for process_id, last_update
|
||||
in self.external_process_last_updated_ms.items()
|
||||
if now - last_update > EXTERNAL_PROCESS_EXPIRY
|
||||
]
|
||||
for process_id in expired_process_ids:
|
||||
users_to_check.update(
|
||||
self.external_process_last_updated_ms.pop(process_id, ())
|
||||
)
|
||||
self.external_process_last_update.pop(process_id)
|
||||
|
||||
states = [
|
||||
self.user_to_current_state.get(
|
||||
user_id, UserPresenceState.default(user_id)
|
||||
)
|
||||
for user_id in set(users_to_check)
|
||||
for user_id in users_to_check
|
||||
]
|
||||
|
||||
timers_fired_counter.inc_by(len(states))
|
||||
|
@ -286,11 +318,13 @@ class PresenceHandler(object):
|
|||
changes = handle_timeouts(
|
||||
states,
|
||||
is_mine_fn=self.is_mine_id,
|
||||
user_to_num_current_syncs=self.user_to_num_current_syncs,
|
||||
syncing_user_ids=self.get_currently_syncing_users(),
|
||||
now=now,
|
||||
)
|
||||
|
||||
preserve_fn(self._update_states)(changes)
|
||||
except:
|
||||
logger.exception("Exception in _handle_timeouts loop")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def bump_presence_active_time(self, user):
|
||||
|
@ -363,6 +397,74 @@ class PresenceHandler(object):
|
|||
|
||||
defer.returnValue(_user_syncing())
|
||||
|
||||
def get_currently_syncing_users(self):
|
||||
"""Get the set of user ids that are currently syncing on this HS.
|
||||
Returns:
|
||||
set(str): A set of user_id strings.
|
||||
"""
|
||||
syncing_user_ids = {
|
||||
user_id for user_id, count in self.user_to_num_current_syncs.items()
|
||||
if count
|
||||
}
|
||||
for user_ids in self.external_process_to_current_syncs.values():
|
||||
syncing_user_ids.update(user_ids)
|
||||
return syncing_user_ids
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_external_syncs(self, process_id, syncing_user_ids):
|
||||
"""Update the syncing users for an external process
|
||||
|
||||
Args:
|
||||
process_id(str): An identifier for the process the users are
|
||||
syncing against. This allows synapse to process updates
|
||||
as user start and stop syncing against a given process.
|
||||
syncing_user_ids(set(str)): The set of user_ids that are
|
||||
currently syncing on that server.
|
||||
"""
|
||||
|
||||
# Grab the previous list of user_ids that were syncing on that process
|
||||
prev_syncing_user_ids = (
|
||||
self.external_process_to_current_syncs.get(process_id, set())
|
||||
)
|
||||
# Grab the current presence state for both the users that are syncing
|
||||
# now and the users that were syncing before this update.
|
||||
prev_states = yield self.current_state_for_users(
|
||||
syncing_user_ids | prev_syncing_user_ids
|
||||
)
|
||||
updates = []
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
# For each new user that is syncing check if we need to mark them as
|
||||
# being online.
|
||||
for new_user_id in syncing_user_ids - prev_syncing_user_ids:
|
||||
prev_state = prev_states[new_user_id]
|
||||
if prev_state.state == PresenceState.OFFLINE:
|
||||
updates.append(prev_state.copy_and_replace(
|
||||
state=PresenceState.ONLINE,
|
||||
last_active_ts=time_now_ms,
|
||||
last_user_sync_ts=time_now_ms,
|
||||
))
|
||||
else:
|
||||
updates.append(prev_state.copy_and_replace(
|
||||
last_user_sync_ts=time_now_ms,
|
||||
))
|
||||
|
||||
# For each user that is still syncing or stopped syncing update the
|
||||
# last sync time so that we will correctly apply the grace period when
|
||||
# they stop syncing.
|
||||
for old_user_id in prev_syncing_user_ids:
|
||||
prev_state = prev_states[old_user_id]
|
||||
updates.append(prev_state.copy_and_replace(
|
||||
last_user_sync_ts=time_now_ms,
|
||||
))
|
||||
|
||||
yield self._update_states(updates)
|
||||
|
||||
# Update the last updated time for the process. We expire the entries
|
||||
# if we don't receive an update in the given timeframe.
|
||||
self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
|
||||
self.external_process_to_current_syncs[process_id] = syncing_user_ids
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def current_state_for_user(self, user_id):
|
||||
"""Get the current presence state for a user.
|
||||
|
@ -879,13 +981,13 @@ class PresenceEventSource(object):
|
|||
|
||||
user_ids_changed = set()
|
||||
changed = None
|
||||
if from_key and max_token - from_key < 100:
|
||||
# For small deltas, its quicker to get all changes and then
|
||||
# work out if we share a room or they're in our presence list
|
||||
if from_key:
|
||||
changed = stream_change_cache.get_all_entities_changed(from_key)
|
||||
|
||||
# get_all_entities_changed can return None
|
||||
if changed is not None:
|
||||
if changed is not None and len(changed) < 500:
|
||||
# For small deltas, its quicker to get all changes and then
|
||||
# work out if we share a room or they're in our presence list
|
||||
get_updates_counter.inc("stream")
|
||||
for other_user_id in changed:
|
||||
if other_user_id in friends:
|
||||
user_ids_changed.add(other_user_id)
|
||||
|
@ -897,6 +999,8 @@ class PresenceEventSource(object):
|
|||
else:
|
||||
# Too many possible updates. Find all users we can see and check
|
||||
# if any of them have changed.
|
||||
get_updates_counter.inc("full")
|
||||
|
||||
user_ids_to_check = set()
|
||||
for room_id in room_ids:
|
||||
users = yield self.store.get_users_in_room(room_id)
|
||||
|
@ -935,15 +1039,14 @@ class PresenceEventSource(object):
|
|||
return self.get_new_events(user, from_key=None, include_offline=False)
|
||||
|
||||
|
||||
def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now):
|
||||
def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
|
||||
"""Checks the presence of users that have timed out and updates as
|
||||
appropriate.
|
||||
|
||||
Args:
|
||||
user_states(list): List of UserPresenceState's to check.
|
||||
is_mine_fn (fn): Function that returns if a user_id is ours
|
||||
user_to_num_current_syncs (dict): Mapping of user_id to number of currently
|
||||
active syncs.
|
||||
syncing_user_ids (set): Set of user_ids with active syncs.
|
||||
now (int): Current time in ms.
|
||||
|
||||
Returns:
|
||||
|
@ -954,21 +1057,20 @@ def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now):
|
|||
for state in user_states:
|
||||
is_mine = is_mine_fn(state.user_id)
|
||||
|
||||
new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now)
|
||||
new_state = handle_timeout(state, is_mine, syncing_user_ids, now)
|
||||
if new_state:
|
||||
changes[state.user_id] = new_state
|
||||
|
||||
return changes.values()
|
||||
|
||||
|
||||
def handle_timeout(state, is_mine, user_to_num_current_syncs, now):
|
||||
def handle_timeout(state, is_mine, syncing_user_ids, now):
|
||||
"""Checks the presence of the user to see if any of the timers have elapsed
|
||||
|
||||
Args:
|
||||
state (UserPresenceState)
|
||||
is_mine (bool): Whether the user is ours
|
||||
user_to_num_current_syncs (dict): Mapping of user_id to number of currently
|
||||
active syncs.
|
||||
syncing_user_ids (set): Set of user_ids with active syncs.
|
||||
now (int): Current time in ms.
|
||||
|
||||
Returns:
|
||||
|
@ -1002,7 +1104,7 @@ def handle_timeout(state, is_mine, user_to_num_current_syncs, now):
|
|||
|
||||
# If there are have been no sync for a while (and none ongoing),
|
||||
# set presence to offline
|
||||
if not user_to_num_current_syncs.get(user_id, 0):
|
||||
if user_id not in syncing_user_ids:
|
||||
if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT:
|
||||
state = state.copy_and_replace(
|
||||
state=PresenceState.OFFLINE,
|
||||
|
|
|
@ -13,15 +13,15 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.types
|
||||
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
||||
from synapse.types import UserID, Requester
|
||||
|
||||
from synapse.types import UserID
|
||||
from ._base import BaseHandler
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -36,13 +36,6 @@ class ProfileHandler(BaseHandler):
|
|||
"profile", self.on_profile_query
|
||||
)
|
||||
|
||||
distributor = hs.get_distributor()
|
||||
|
||||
distributor.observe("registered_user", self.registered_user)
|
||||
|
||||
def registered_user(self, user):
|
||||
return self.store.create_profile(user.localpart)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_displayname(self, target_user):
|
||||
if self.hs.is_mine(target_user):
|
||||
|
@ -172,7 +165,9 @@ class ProfileHandler(BaseHandler):
|
|||
try:
|
||||
# Assume the user isn't a guest because we don't let guests set
|
||||
# profile or avatar data.
|
||||
requester = Requester(user, "", False)
|
||||
# XXX why are we recreating `requester` here for each room?
|
||||
# what was wrong with the `requester` we were passed?
|
||||
requester = synapse.types.create_requester(user)
|
||||
yield handler.update_membership(
|
||||
requester,
|
||||
user,
|
||||
|
|
|
@ -14,19 +14,19 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""Contains functions for registering clients."""
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.types import UserID
|
||||
import synapse.types
|
||||
from synapse.api.errors import (
|
||||
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
||||
)
|
||||
from ._base import BaseHandler
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.http.client import CaptchaServerHttpClient
|
||||
from synapse.util.distributor import registered_user
|
||||
|
||||
import logging
|
||||
import urllib
|
||||
from synapse.types import UserID
|
||||
from synapse.util.async import run_on_reactor
|
||||
from ._base import BaseHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -37,8 +37,6 @@ class RegistrationHandler(BaseHandler):
|
|||
super(RegistrationHandler, self).__init__(hs)
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
self.distributor = hs.get_distributor()
|
||||
self.distributor.declare("registered_user")
|
||||
self.captcha_client = CaptchaServerHttpClient(hs)
|
||||
|
||||
self._next_generated_user_id = None
|
||||
|
@ -55,6 +53,13 @@ class RegistrationHandler(BaseHandler):
|
|||
Codes.INVALID_USERNAME
|
||||
)
|
||||
|
||||
if localpart[0] == '_':
|
||||
raise SynapseError(
|
||||
400,
|
||||
"User ID may not begin with _",
|
||||
Codes.INVALID_USERNAME
|
||||
)
|
||||
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
|
||||
|
@ -93,7 +98,8 @@ class RegistrationHandler(BaseHandler):
|
|||
password=None,
|
||||
generate_token=True,
|
||||
guest_access_token=None,
|
||||
make_guest=False
|
||||
make_guest=False,
|
||||
admin=False,
|
||||
):
|
||||
"""Registers a new client on the server.
|
||||
|
||||
|
@ -103,6 +109,11 @@ class RegistrationHandler(BaseHandler):
|
|||
password (str) : The password to assign to this user so they can
|
||||
login again. This can be None which means they cannot login again
|
||||
via a password (e.g. the user is an application service user).
|
||||
generate_token (bool): Whether a new access token should be
|
||||
generated. Having this be True should be considered deprecated,
|
||||
since it offers no means of associating a device_id with the
|
||||
access_token. Instead you should call auth_handler.issue_access_token
|
||||
after registration.
|
||||
Returns:
|
||||
A tuple of (user_id, access_token).
|
||||
Raises:
|
||||
|
@ -140,9 +151,12 @@ class RegistrationHandler(BaseHandler):
|
|||
password_hash=password_hash,
|
||||
was_guest=was_guest,
|
||||
make_guest=make_guest,
|
||||
create_profile_with_localpart=(
|
||||
# If the user was a guest then they already have a profile
|
||||
None if was_guest else user.localpart
|
||||
),
|
||||
admin=admin,
|
||||
)
|
||||
|
||||
yield registered_user(self.distributor, user)
|
||||
else:
|
||||
# autogen a sequential user ID
|
||||
attempts = 0
|
||||
|
@ -160,7 +174,8 @@ class RegistrationHandler(BaseHandler):
|
|||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=password_hash,
|
||||
make_guest=make_guest
|
||||
make_guest=make_guest,
|
||||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
except SynapseError:
|
||||
# if user id is taken, just generate another
|
||||
|
@ -168,7 +183,6 @@ class RegistrationHandler(BaseHandler):
|
|||
user_id = None
|
||||
token = None
|
||||
attempts += 1
|
||||
yield registered_user(self.distributor, user)
|
||||
|
||||
# We used to generate default identicons here, but nowadays
|
||||
# we want clients to generate their own as part of their branding
|
||||
|
@ -195,15 +209,13 @@ class RegistrationHandler(BaseHandler):
|
|||
user_id, allowed_appservice=service
|
||||
)
|
||||
|
||||
token = self.auth_handler().generate_access_token(user_id)
|
||||
yield self.store.register(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash="",
|
||||
appservice_id=service_id,
|
||||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
yield registered_user(self.distributor, user)
|
||||
defer.returnValue((user_id, token))
|
||||
defer.returnValue(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_recaptcha(self, ip, private_key, challenge, response):
|
||||
|
@ -248,9 +260,9 @@ class RegistrationHandler(BaseHandler):
|
|||
yield self.store.register(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=None
|
||||
password_hash=None,
|
||||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
yield registered_user(self.distributor, user)
|
||||
except Exception as e:
|
||||
yield self.store.add_access_token_to_user(user_id, token)
|
||||
# Ignore Registration errors
|
||||
|
@ -359,8 +371,10 @@ class RegistrationHandler(BaseHandler):
|
|||
defer.returnValue(data)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_or_create_user(self, localpart, displayname, duration_seconds):
|
||||
"""Creates a new user or returns an access token for an existing one
|
||||
def get_or_create_user(self, localpart, displayname, duration_in_ms,
|
||||
password_hash=None):
|
||||
"""Creates a new user if the user does not exist,
|
||||
else revokes all previous access tokens and generates a new one.
|
||||
|
||||
Args:
|
||||
localpart : The local part of the user ID to register. If None,
|
||||
|
@ -387,32 +401,32 @@ class RegistrationHandler(BaseHandler):
|
|||
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
auth_handler = self.hs.get_handlers().auth_handler
|
||||
token = auth_handler.generate_short_term_login_token(user_id, duration_seconds)
|
||||
token = self.auth_handler().generate_access_token(
|
||||
user_id, None, duration_in_ms)
|
||||
|
||||
if need_register:
|
||||
yield self.store.register(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=None
|
||||
password_hash=password_hash,
|
||||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
|
||||
yield registered_user(self.distributor, user)
|
||||
else:
|
||||
yield self.store.flush_user(user_id=user_id)
|
||||
yield self.store.user_delete_access_tokens(user_id=user_id)
|
||||
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
|
||||
|
||||
if displayname is not None:
|
||||
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||
profile_handler = self.hs.get_handlers().profile_handler
|
||||
requester = synapse.types.create_requester(user)
|
||||
yield profile_handler.set_displayname(
|
||||
user, user, displayname
|
||||
user, requester, displayname
|
||||
)
|
||||
|
||||
defer.returnValue((user_id, token))
|
||||
|
||||
def auth_handler(self):
|
||||
return self.hs.get_handlers().auth_handler
|
||||
return self.hs.get_auth_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def guest_access_token_for(self, medium, address, inviter_user_id):
|
||||
|
|
|
@ -20,7 +20,7 @@ from ._base import BaseHandler
|
|||
|
||||
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
|
||||
from synapse.api.constants import (
|
||||
EventTypes, JoinRules, RoomCreationPreset,
|
||||
EventTypes, JoinRules, RoomCreationPreset, Membership,
|
||||
)
|
||||
from synapse.api.errors import AuthError, StoreError, SynapseError
|
||||
from synapse.util import stringutils
|
||||
|
@ -36,6 +36,8 @@ import string
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
|
||||
|
||||
id_server_scheme = "https://"
|
||||
|
||||
|
||||
|
@ -343,9 +345,15 @@ class RoomCreationHandler(BaseHandler):
|
|||
class RoomListHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
super(RoomListHandler, self).__init__(hs)
|
||||
self.response_cache = ResponseCache()
|
||||
self.response_cache = ResponseCache(hs)
|
||||
self.remote_list_request_cache = ResponseCache(hs)
|
||||
self.remote_list_cache = {}
|
||||
self.fetch_looping_call = hs.get_clock().looping_call(
|
||||
self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL
|
||||
)
|
||||
self.fetch_all_remote_lists()
|
||||
|
||||
def get_public_room_list(self):
|
||||
def get_local_public_room_list(self):
|
||||
result = self.response_cache.get(())
|
||||
if not result:
|
||||
result = self.response_cache.set((), self._get_public_room_list())
|
||||
|
@ -359,14 +367,10 @@ class RoomListHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def handle_room(room_id):
|
||||
# We pull each bit of state out indvidually to avoid pulling the
|
||||
# full state into memory. Due to how the caching works this should
|
||||
# be fairly quick, even if not originally in the cache.
|
||||
def get_state(etype, state_key):
|
||||
return self.state_handler.get_current_state(room_id, etype, state_key)
|
||||
current_state = yield self.state_handler.get_current_state(room_id)
|
||||
|
||||
# Double check that this is actually a public room.
|
||||
join_rules_event = yield get_state(EventTypes.JoinRules, "")
|
||||
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
|
||||
if join_rules_event:
|
||||
join_rule = join_rules_event.content.get("join_rule", None)
|
||||
if join_rule and join_rule != JoinRules.PUBLIC:
|
||||
|
@ -374,47 +378,51 @@ class RoomListHandler(BaseHandler):
|
|||
|
||||
result = {"room_id": room_id}
|
||||
|
||||
joined_users = yield self.store.get_users_in_room(room_id)
|
||||
if len(joined_users) == 0:
|
||||
num_joined_users = len([
|
||||
1 for _, event in current_state.items()
|
||||
if event.type == EventTypes.Member
|
||||
and event.membership == Membership.JOIN
|
||||
])
|
||||
if num_joined_users == 0:
|
||||
return
|
||||
|
||||
result["num_joined_members"] = len(joined_users)
|
||||
result["num_joined_members"] = num_joined_users
|
||||
|
||||
aliases = yield self.store.get_aliases_for_room(room_id)
|
||||
if aliases:
|
||||
result["aliases"] = aliases
|
||||
|
||||
name_event = yield get_state(EventTypes.Name, "")
|
||||
name_event = yield current_state.get((EventTypes.Name, ""))
|
||||
if name_event:
|
||||
name = name_event.content.get("name", None)
|
||||
if name:
|
||||
result["name"] = name
|
||||
|
||||
topic_event = yield get_state(EventTypes.Topic, "")
|
||||
topic_event = current_state.get((EventTypes.Topic, ""))
|
||||
if topic_event:
|
||||
topic = topic_event.content.get("topic", None)
|
||||
if topic:
|
||||
result["topic"] = topic
|
||||
|
||||
canonical_event = yield get_state(EventTypes.CanonicalAlias, "")
|
||||
canonical_event = current_state.get((EventTypes.CanonicalAlias, ""))
|
||||
if canonical_event:
|
||||
canonical_alias = canonical_event.content.get("alias", None)
|
||||
if canonical_alias:
|
||||
result["canonical_alias"] = canonical_alias
|
||||
|
||||
visibility_event = yield get_state(EventTypes.RoomHistoryVisibility, "")
|
||||
visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, ""))
|
||||
visibility = None
|
||||
if visibility_event:
|
||||
visibility = visibility_event.content.get("history_visibility", None)
|
||||
result["world_readable"] = visibility == "world_readable"
|
||||
|
||||
guest_event = yield get_state(EventTypes.GuestAccess, "")
|
||||
guest_event = current_state.get((EventTypes.GuestAccess, ""))
|
||||
guest = None
|
||||
if guest_event:
|
||||
guest = guest_event.content.get("guest_access", None)
|
||||
result["guest_can_join"] = guest == "can_join"
|
||||
|
||||
avatar_event = yield get_state("m.room.avatar", "")
|
||||
avatar_event = current_state.get(("m.room.avatar", ""))
|
||||
if avatar_event:
|
||||
avatar_url = avatar_event.content.get("url", None)
|
||||
if avatar_url:
|
||||
|
@ -427,6 +435,55 @@ class RoomListHandler(BaseHandler):
|
|||
# FIXME (erikj): START is no longer a valid value
|
||||
defer.returnValue({"start": "START", "end": "END", "chunk": results})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fetch_all_remote_lists(self):
|
||||
deferred = self.hs.get_replication_layer().get_public_rooms(
|
||||
self.hs.config.secondary_directory_servers
|
||||
)
|
||||
self.remote_list_request_cache.set((), deferred)
|
||||
self.remote_list_cache = yield deferred
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_aggregated_public_room_list(self):
|
||||
"""
|
||||
Get the public room list from this server and the servers
|
||||
specified in the secondary_directory_servers config option.
|
||||
XXX: Pagination...
|
||||
"""
|
||||
# We return the results from out cache which is updated by a looping call,
|
||||
# unless we're missing a cache entry, in which case wait for the result
|
||||
# of the fetch if there's one in progress. If not, omit that server.
|
||||
wait = False
|
||||
for s in self.hs.config.secondary_directory_servers:
|
||||
if s not in self.remote_list_cache:
|
||||
logger.warn("No cached room list from %s: waiting for fetch", s)
|
||||
wait = True
|
||||
break
|
||||
|
||||
if wait and self.remote_list_request_cache.get(()):
|
||||
yield self.remote_list_request_cache.get(())
|
||||
|
||||
public_rooms = yield self.get_local_public_room_list()
|
||||
|
||||
# keep track of which room IDs we've seen so we can de-dup
|
||||
room_ids = set()
|
||||
|
||||
# tag all the ones in our list with our server name.
|
||||
# Also add the them to the de-deping set
|
||||
for room in public_rooms['chunk']:
|
||||
room["server_name"] = self.hs.hostname
|
||||
room_ids.add(room["room_id"])
|
||||
|
||||
# Now add the results from federation
|
||||
for server_name, server_result in self.remote_list_cache.items():
|
||||
for room in server_result["chunk"]:
|
||||
if room["room_id"] not in room_ids:
|
||||
room["server_name"] = server_name
|
||||
public_rooms["chunk"].append(room)
|
||||
room_ids.add(room["room_id"])
|
||||
|
||||
defer.returnValue(public_rooms)
|
||||
|
||||
|
||||
class RoomContextHandler(BaseHandler):
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -14,24 +14,22 @@
|
|||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.sign import verify_signed_json
|
||||
from twisted.internet import defer
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
from synapse.types import UserID, RoomID, Requester
|
||||
import synapse.types
|
||||
from synapse.api.constants import (
|
||||
EventTypes, Membership,
|
||||
)
|
||||
from synapse.api.errors import AuthError, SynapseError, Codes
|
||||
from synapse.types import UserID, RoomID
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.distributor import user_left_room, user_joined_room
|
||||
|
||||
from signedjson.sign import verify_signed_json
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
import logging
|
||||
from ._base import BaseHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
)
|
||||
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
|
||||
else:
|
||||
requester = Requester(target_user, None, False)
|
||||
requester = synapse.types.create_requester(target_user)
|
||||
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
prev_event = message_handler.deduplicate_state_event(event, context)
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# A tiny object useful for storing a user's membership in a room, as a mapping
|
||||
# key
|
||||
RoomMember = namedtuple("RoomMember", ("room_id", "user"))
|
||||
RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
|
||||
|
||||
|
||||
class TypingHandler(object):
|
||||
|
@ -38,7 +38,7 @@ class TypingHandler(object):
|
|||
self.store = hs.get_datastore()
|
||||
self.server_name = hs.config.server_name
|
||||
self.auth = hs.get_auth()
|
||||
self.is_mine = hs.is_mine
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.notifier = hs.get_notifier()
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
|
@ -67,20 +67,23 @@ class TypingHandler(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def started_typing(self, target_user, auth_user, room_id, timeout):
|
||||
if not self.is_mine(target_user):
|
||||
target_user_id = target_user.to_string()
|
||||
auth_user_id = auth_user.to_string()
|
||||
|
||||
if not self.is_mine_id(target_user_id):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
if target_user != auth_user:
|
||||
if target_user_id != auth_user_id:
|
||||
raise AuthError(400, "Cannot set another user's typing state")
|
||||
|
||||
yield self.auth.check_joined_room(room_id, target_user.to_string())
|
||||
yield self.auth.check_joined_room(room_id, target_user_id)
|
||||
|
||||
logger.debug(
|
||||
"%s has started typing in %s", target_user.to_string(), room_id
|
||||
"%s has started typing in %s", target_user_id, room_id
|
||||
)
|
||||
|
||||
until = self.clock.time_msec() + timeout
|
||||
member = RoomMember(room_id=room_id, user=target_user)
|
||||
member = RoomMember(room_id=room_id, user_id=target_user_id)
|
||||
|
||||
was_present = member in self._member_typing_until
|
||||
|
||||
|
@ -104,25 +107,28 @@ class TypingHandler(object):
|
|||
|
||||
yield self._push_update(
|
||||
room_id=room_id,
|
||||
user=target_user,
|
||||
user_id=target_user_id,
|
||||
typing=True,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def stopped_typing(self, target_user, auth_user, room_id):
|
||||
if not self.is_mine(target_user):
|
||||
target_user_id = target_user.to_string()
|
||||
auth_user_id = auth_user.to_string()
|
||||
|
||||
if not self.is_mine_id(target_user_id):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
if target_user != auth_user:
|
||||
if target_user_id != auth_user_id:
|
||||
raise AuthError(400, "Cannot set another user's typing state")
|
||||
|
||||
yield self.auth.check_joined_room(room_id, target_user.to_string())
|
||||
yield self.auth.check_joined_room(room_id, target_user_id)
|
||||
|
||||
logger.debug(
|
||||
"%s has stopped typing in %s", target_user.to_string(), room_id
|
||||
"%s has stopped typing in %s", target_user_id, room_id
|
||||
)
|
||||
|
||||
member = RoomMember(room_id=room_id, user=target_user)
|
||||
member = RoomMember(room_id=room_id, user_id=target_user_id)
|
||||
|
||||
if member in self._member_typing_timer:
|
||||
self.clock.cancel_call_later(self._member_typing_timer[member])
|
||||
|
@ -132,8 +138,9 @@ class TypingHandler(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def user_left_room(self, user, room_id):
|
||||
if self.is_mine(user):
|
||||
member = RoomMember(room_id=room_id, user=user)
|
||||
user_id = user.to_string()
|
||||
if self.is_mine_id(user_id):
|
||||
member = RoomMember(room_id=room_id, user_id=user_id)
|
||||
yield self._stopped_typing(member)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -144,7 +151,7 @@ class TypingHandler(object):
|
|||
|
||||
yield self._push_update(
|
||||
room_id=member.room_id,
|
||||
user=member.user,
|
||||
user_id=member.user_id,
|
||||
typing=False,
|
||||
)
|
||||
|
||||
|
@ -156,7 +163,7 @@ class TypingHandler(object):
|
|||
del self._member_typing_timer[member]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _push_update(self, room_id, user, typing):
|
||||
def _push_update(self, room_id, user_id, typing):
|
||||
domains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
|
||||
deferreds = []
|
||||
|
@ -164,7 +171,7 @@ class TypingHandler(object):
|
|||
if domain == self.server_name:
|
||||
self._push_update_local(
|
||||
room_id=room_id,
|
||||
user=user,
|
||||
user_id=user_id,
|
||||
typing=typing
|
||||
)
|
||||
else:
|
||||
|
@ -173,7 +180,7 @@ class TypingHandler(object):
|
|||
edu_type="m.typing",
|
||||
content={
|
||||
"room_id": room_id,
|
||||
"user_id": user.to_string(),
|
||||
"user_id": user_id,
|
||||
"typing": typing,
|
||||
},
|
||||
))
|
||||
|
@ -183,23 +190,26 @@ class TypingHandler(object):
|
|||
@defer.inlineCallbacks
|
||||
def _recv_edu(self, origin, content):
|
||||
room_id = content["room_id"]
|
||||
user = UserID.from_string(content["user_id"])
|
||||
user_id = content["user_id"]
|
||||
|
||||
# Check that the string is a valid user id
|
||||
UserID.from_string(user_id)
|
||||
|
||||
domains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
|
||||
if self.server_name in domains:
|
||||
self._push_update_local(
|
||||
room_id=room_id,
|
||||
user=user,
|
||||
user_id=user_id,
|
||||
typing=content["typing"]
|
||||
)
|
||||
|
||||
def _push_update_local(self, room_id, user, typing):
|
||||
def _push_update_local(self, room_id, user_id, typing):
|
||||
room_set = self._room_typing.setdefault(room_id, set())
|
||||
if typing:
|
||||
room_set.add(user)
|
||||
room_set.add(user_id)
|
||||
else:
|
||||
room_set.discard(user)
|
||||
room_set.discard(user_id)
|
||||
|
||||
self._latest_room_serial += 1
|
||||
self._room_serials[room_id] = self._latest_room_serial
|
||||
|
@ -211,13 +221,14 @@ class TypingHandler(object):
|
|||
|
||||
def get_all_typing_updates(self, last_id, current_id):
|
||||
# TODO: Work out a way to do this without scanning the entire state.
|
||||
if last_id == current_id:
|
||||
return []
|
||||
|
||||
rows = []
|
||||
for room_id, serial in self._room_serials.items():
|
||||
if last_id < serial and serial <= current_id:
|
||||
typing = self._room_typing[room_id]
|
||||
typing_bytes = json.dumps([
|
||||
u.to_string() for u in typing
|
||||
], ensure_ascii=False)
|
||||
typing_bytes = json.dumps(list(typing), ensure_ascii=False)
|
||||
rows.append((serial, room_id, typing_bytes))
|
||||
rows.sort()
|
||||
return rows
|
||||
|
@ -239,7 +250,7 @@ class TypingNotificationEventSource(object):
|
|||
"type": "m.typing",
|
||||
"room_id": room_id,
|
||||
"content": {
|
||||
"user_ids": [u.to_string() for u in typing],
|
||||
"user_ids": list(typing),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
@ -24,12 +24,13 @@ from synapse.http.endpoint import SpiderEndpoint
|
|||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from twisted.internet import defer, reactor, ssl, protocol
|
||||
from twisted.internet import defer, reactor, ssl, protocol, task
|
||||
from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
|
||||
from twisted.web.client import (
|
||||
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
|
||||
readBody, FileBodyProducer, PartialDownloadError,
|
||||
readBody, PartialDownloadError,
|
||||
)
|
||||
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
|
||||
from twisted.web.http import PotentialDataLoss
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web._newclient import ResponseDone
|
||||
|
@ -468,3 +469,26 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
|
|||
|
||||
def creatorForNetloc(self, hostname, port):
|
||||
return self
|
||||
|
||||
|
||||
class FileBodyProducer(TwistedFileBodyProducer):
|
||||
"""Workaround for https://twistedmatrix.com/trac/ticket/8473
|
||||
|
||||
We override the pauseProducing and resumeProducing methods in twisted's
|
||||
FileBodyProducer so that they do not raise exceptions if the task has
|
||||
already completed.
|
||||
"""
|
||||
|
||||
def pauseProducing(self):
|
||||
try:
|
||||
super(FileBodyProducer, self).pauseProducing()
|
||||
except task.TaskDone:
|
||||
# task has already completed
|
||||
pass
|
||||
|
||||
def resumeProducing(self):
|
||||
try:
|
||||
super(FileBodyProducer, self).resumeProducing()
|
||||
except task.NotPaused:
|
||||
# task was not paused (probably because it had already completed)
|
||||
pass
|
||||
|
|
|
@ -155,9 +155,7 @@ class MatrixFederationHttpClient(object):
|
|||
time_out=timeout / 1000. if timeout else 60,
|
||||
)
|
||||
|
||||
response = yield preserve_context_over_fn(
|
||||
send_request,
|
||||
)
|
||||
response = yield preserve_context_over_fn(send_request)
|
||||
|
||||
log_result = "%d %s" % (response.code, response.phrase,)
|
||||
break
|
||||
|
|
|
@ -205,6 +205,7 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
|
||||
def register_paths(self, method, path_patterns, callback):
|
||||
for path_pattern in path_patterns:
|
||||
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
||||
self.path_regexs.setdefault(method, []).append(
|
||||
self._PathEntry(path_pattern, callback)
|
||||
)
|
||||
|
|
|
@ -22,22 +22,20 @@ import functools
|
|||
import os
|
||||
import stat
|
||||
import time
|
||||
import gc
|
||||
|
||||
from twisted.internet import reactor
|
||||
|
||||
from .metric import (
|
||||
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
|
||||
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
|
||||
MemoryUsageMetric,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# We'll keep all the available metrics in a single toplevel dict, one shared
|
||||
# for the entire process. We don't currently support per-HomeServer instances
|
||||
# of metrics, because in practice any one python VM will host only one
|
||||
# HomeServer anyway. This makes a lot of implementation neater
|
||||
all_metrics = {}
|
||||
all_metrics = []
|
||||
|
||||
|
||||
class Metrics(object):
|
||||
|
@ -53,7 +51,7 @@ class Metrics(object):
|
|||
|
||||
metric = metric_class(full_name, *args, **kwargs)
|
||||
|
||||
all_metrics[full_name] = metric
|
||||
all_metrics.append(metric)
|
||||
return metric
|
||||
|
||||
def register_counter(self, *args, **kwargs):
|
||||
|
@ -69,6 +67,21 @@ class Metrics(object):
|
|||
return self._register(CacheMetric, *args, **kwargs)
|
||||
|
||||
|
||||
def register_memory_metrics(hs):
|
||||
try:
|
||||
import psutil
|
||||
process = psutil.Process()
|
||||
process.memory_info().rss
|
||||
except (ImportError, AttributeError):
|
||||
logger.warn(
|
||||
"psutil is not installed or incorrect version."
|
||||
" Disabling memory metrics."
|
||||
)
|
||||
return
|
||||
metric = MemoryUsageMetric(hs, psutil)
|
||||
all_metrics.append(metric)
|
||||
|
||||
|
||||
def get_metrics_for(pkg_name):
|
||||
""" Returns a Metrics instance for conveniently creating metrics
|
||||
namespaced with the given name prefix. """
|
||||
|
@ -84,12 +97,12 @@ def render_all():
|
|||
# TODO(paul): Internal hack
|
||||
update_resource_metrics()
|
||||
|
||||
for name in sorted(all_metrics.keys()):
|
||||
for metric in all_metrics:
|
||||
try:
|
||||
strs += all_metrics[name].render()
|
||||
strs += metric.render()
|
||||
except Exception:
|
||||
strs += ["# FAILED to render %s" % name]
|
||||
logger.exception("Failed to render %s metric", name)
|
||||
strs += ["# FAILED to render"]
|
||||
logger.exception("Failed to render metric")
|
||||
|
||||
strs.append("") # to generate a final CRLF
|
||||
|
||||
|
@ -156,6 +169,13 @@ reactor_metrics = get_metrics_for("reactor")
|
|||
tick_time = reactor_metrics.register_distribution("tick_time")
|
||||
pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
|
||||
|
||||
gc_time = reactor_metrics.register_distribution("gc_time", labels=["gen"])
|
||||
gc_unreachable = reactor_metrics.register_counter("gc_unreachable", labels=["gen"])
|
||||
|
||||
reactor_metrics.register_callback(
|
||||
"gc_counts", lambda: {(i,): v for i, v in enumerate(gc.get_count())}, labels=["gen"]
|
||||
)
|
||||
|
||||
|
||||
def runUntilCurrentTimer(func):
|
||||
|
||||
|
@ -182,6 +202,22 @@ def runUntilCurrentTimer(func):
|
|||
end = time.time() * 1000
|
||||
tick_time.inc_by(end - start)
|
||||
pending_calls_metric.inc_by(num_pending)
|
||||
|
||||
# Check if we need to do a manual GC (since its been disabled), and do
|
||||
# one if necessary.
|
||||
threshold = gc.get_threshold()
|
||||
counts = gc.get_count()
|
||||
for i in (2, 1, 0):
|
||||
if threshold[i] < counts[i]:
|
||||
logger.info("Collecting gc %d", i)
|
||||
|
||||
start = time.time() * 1000
|
||||
unreachable = gc.collect(i)
|
||||
end = time.time() * 1000
|
||||
|
||||
gc_time.inc_by(end - start, i)
|
||||
gc_unreachable.inc_by(unreachable, i)
|
||||
|
||||
return ret
|
||||
|
||||
return f
|
||||
|
@ -196,5 +232,9 @@ try:
|
|||
# runUntilCurrent is called when we have pending calls. It is called once
|
||||
# per iteratation after fd polling.
|
||||
reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
|
||||
|
||||
# We manually run the GC each reactor tick so that we can get some metrics
|
||||
# about time spent doing GC,
|
||||
gc.disable()
|
||||
except AttributeError:
|
||||
pass
|
||||
|
|
|
@ -47,9 +47,6 @@ class BaseMetric(object):
|
|||
for k, v in zip(self.labels, values)])
|
||||
)
|
||||
|
||||
def render(self):
|
||||
return map_concat(self.render_item, sorted(self.counts.keys()))
|
||||
|
||||
|
||||
class CounterMetric(BaseMetric):
|
||||
"""The simplest kind of metric; one that stores a monotonically-increasing
|
||||
|
@ -83,6 +80,9 @@ class CounterMetric(BaseMetric):
|
|||
def render_item(self, k):
|
||||
return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
|
||||
|
||||
def render(self):
|
||||
return map_concat(self.render_item, sorted(self.counts.keys()))
|
||||
|
||||
|
||||
class CallbackMetric(BaseMetric):
|
||||
"""A metric that returns the numeric value returned by a callback whenever
|
||||
|
@ -126,30 +126,70 @@ class DistributionMetric(object):
|
|||
|
||||
|
||||
class CacheMetric(object):
|
||||
"""A combination of two CounterMetrics, one to count cache hits and one to
|
||||
count a total, and a callback metric to yield the current size.
|
||||
__slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
|
||||
|
||||
This metric generates standard metric name pairs, so that monitoring rules
|
||||
can easily be applied to measure hit ratio."""
|
||||
|
||||
def __init__(self, name, size_callback, labels=[]):
|
||||
def __init__(self, name, size_callback, cache_name):
|
||||
self.name = name
|
||||
self.cache_name = cache_name
|
||||
|
||||
self.hits = CounterMetric(name + ":hits", labels=labels)
|
||||
self.total = CounterMetric(name + ":total", labels=labels)
|
||||
self.hits = 0
|
||||
self.misses = 0
|
||||
|
||||
self.size = CallbackMetric(
|
||||
name + ":size",
|
||||
callback=size_callback,
|
||||
labels=labels,
|
||||
)
|
||||
self.size_callback = size_callback
|
||||
|
||||
def inc_hits(self, *values):
|
||||
self.hits.inc(*values)
|
||||
self.total.inc(*values)
|
||||
def inc_hits(self):
|
||||
self.hits += 1
|
||||
|
||||
def inc_misses(self, *values):
|
||||
self.total.inc(*values)
|
||||
def inc_misses(self):
|
||||
self.misses += 1
|
||||
|
||||
def render(self):
|
||||
return self.hits.render() + self.total.render() + self.size.render()
|
||||
size = self.size_callback()
|
||||
hits = self.hits
|
||||
total = self.misses + self.hits
|
||||
|
||||
return [
|
||||
"""%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
|
||||
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
|
||||
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
|
||||
]
|
||||
|
||||
|
||||
class MemoryUsageMetric(object):
|
||||
"""Keeps track of the current memory usage, using psutil.
|
||||
|
||||
The class will keep the current min/max/sum/counts of rss over the last
|
||||
WINDOW_SIZE_SEC, by polling UPDATE_HZ times per second
|
||||
"""
|
||||
|
||||
UPDATE_HZ = 2 # number of times to get memory per second
|
||||
WINDOW_SIZE_SEC = 30 # the size of the window in seconds
|
||||
|
||||
def __init__(self, hs, psutil):
|
||||
clock = hs.get_clock()
|
||||
self.memory_snapshots = []
|
||||
|
||||
self.process = psutil.Process()
|
||||
|
||||
clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ)
|
||||
|
||||
def _update_curr_values(self):
|
||||
max_size = self.UPDATE_HZ * self.WINDOW_SIZE_SEC
|
||||
self.memory_snapshots.append(self.process.memory_info().rss)
|
||||
self.memory_snapshots[:] = self.memory_snapshots[-max_size:]
|
||||
|
||||
def render(self):
|
||||
if not self.memory_snapshots:
|
||||
return []
|
||||
|
||||
max_rss = max(self.memory_snapshots)
|
||||
min_rss = min(self.memory_snapshots)
|
||||
sum_rss = sum(self.memory_snapshots)
|
||||
len_rss = len(self.memory_snapshots)
|
||||
|
||||
return [
|
||||
"process_psutil_rss:max %d" % max_rss,
|
||||
"process_psutil_rss:min %d" % min_rss,
|
||||
"process_psutil_rss:total %d" % sum_rss,
|
||||
"process_psutil_rss:count %d" % len_rss,
|
||||
]
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import AuthError
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
|
@ -140,8 +140,6 @@ class Notifier(object):
|
|||
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
|
||||
self.user_to_user_stream = {}
|
||||
self.room_to_user_streams = {}
|
||||
self.appservice_to_user_streams = {}
|
||||
|
@ -151,10 +149,8 @@ class Notifier(object):
|
|||
self.pending_new_room_events = []
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
hs.get_distributor().observe(
|
||||
"user_joined_room", self._user_joined_room
|
||||
)
|
||||
self.appservice_handler = hs.get_application_service_handler()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
|
||||
self.clock.looping_call(
|
||||
self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
|
||||
|
@ -232,9 +228,7 @@ class Notifier(object):
|
|||
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
|
||||
"""Notify any user streams that are interested in this room event"""
|
||||
# poke any interested application service.
|
||||
self.hs.get_handlers().appservice_handler.notify_interested_services(
|
||||
event
|
||||
)
|
||||
self.appservice_handler.notify_interested_services(event)
|
||||
|
||||
app_streams = set()
|
||||
|
||||
|
@ -250,6 +244,9 @@ class Notifier(object):
|
|||
)
|
||||
app_streams |= app_user_streams
|
||||
|
||||
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
|
||||
self._user_joined_room(event.state_key, event.room_id)
|
||||
|
||||
self.on_new_event(
|
||||
"room_key", room_stream_id,
|
||||
users=extra_users,
|
||||
|
@ -449,7 +446,7 @@ class Notifier(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _is_world_readable(self, room_id):
|
||||
state = yield self.hs.get_state_handler().get_current_state(
|
||||
state = yield self.state_handler.get_current_state(
|
||||
room_id,
|
||||
EventTypes.RoomHistoryVisibility
|
||||
)
|
||||
|
@ -485,9 +482,8 @@ class Notifier(object):
|
|||
user_stream.appservice, set()
|
||||
).add(user_stream)
|
||||
|
||||
def _user_joined_room(self, user, room_id):
|
||||
user = str(user)
|
||||
new_user_stream = self.user_to_user_stream.get(user)
|
||||
def _user_joined_room(self, user_id, room_id):
|
||||
new_user_stream = self.user_to_user_stream.get(user_id)
|
||||
if new_user_stream is not None:
|
||||
room_streams = self.room_to_user_streams.setdefault(room_id, set())
|
||||
room_streams.add(new_user_stream)
|
||||
|
|
|
@ -40,7 +40,7 @@ class ActionGenerator:
|
|||
def handle_push_actions_for_event(self, event, context):
|
||||
with Measure(self.clock, "handle_push_actions_for_event"):
|
||||
bulk_evaluator = yield evaluator_for_event(
|
||||
event, self.hs, self.store
|
||||
event, self.hs, self.store, context.current_state
|
||||
)
|
||||
|
||||
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
|
||||
|
|
|
@ -14,84 +14,56 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import ujson as json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from .baserules import list_with_base_rules
|
||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.visibility import filter_events_for_clients
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def decode_rule_json(rule):
|
||||
rule['conditions'] = json.loads(rule['conditions'])
|
||||
rule['actions'] = json.loads(rule['actions'])
|
||||
return rule
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_rules(room_id, user_ids, store):
|
||||
rules_by_user = yield store.bulk_get_push_rules(user_ids)
|
||||
rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)
|
||||
|
||||
rules_by_user = {
|
||||
uid: list_with_base_rules([
|
||||
decode_rule_json(rule_list)
|
||||
for rule_list in rules_by_user.get(uid, [])
|
||||
])
|
||||
for uid in user_ids
|
||||
}
|
||||
|
||||
# We apply the rules-enabled map here: bulk_get_push_rules doesn't
|
||||
# fetch disabled rules, but this won't account for any server default
|
||||
# rules the user has disabled, so we need to do this too.
|
||||
for uid in user_ids:
|
||||
if uid not in rules_enabled_by_user:
|
||||
continue
|
||||
|
||||
user_enabled_map = rules_enabled_by_user[uid]
|
||||
|
||||
for i, rule in enumerate(rules_by_user[uid]):
|
||||
rule_id = rule['rule_id']
|
||||
|
||||
if rule_id in user_enabled_map:
|
||||
if rule.get('enabled', True) != bool(user_enabled_map[rule_id]):
|
||||
# Rules are cached across users.
|
||||
rule = dict(rule)
|
||||
rule['enabled'] = bool(user_enabled_map[rule_id])
|
||||
rules_by_user[uid][i] = rule
|
||||
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
|
||||
|
||||
defer.returnValue(rules_by_user)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def evaluator_for_event(event, hs, store):
|
||||
def evaluator_for_event(event, hs, store, current_state):
|
||||
room_id = event.room_id
|
||||
|
||||
# users in the room who have pushers need to get push rules run because
|
||||
# that's how their pushers work
|
||||
users_with_pushers = yield store.get_users_with_pushers_in_room(room_id)
|
||||
|
||||
# We also will want to generate notifs for other people in the room so
|
||||
# their unread countss are correct in the event stream, but to avoid
|
||||
# generating them for bot / AS users etc, we only do so for people who've
|
||||
# sent a read receipt into the room.
|
||||
|
||||
all_in_room = yield store.get_users_in_room(room_id)
|
||||
all_in_room = set(all_in_room)
|
||||
local_users_in_room = set(
|
||||
e.state_key for e in current_state.values()
|
||||
if e.type == EventTypes.Member and e.membership == Membership.JOIN
|
||||
and hs.is_mine_id(e.state_key)
|
||||
)
|
||||
|
||||
receipts = yield store.get_receipts_for_room(room_id, "m.read")
|
||||
# users in the room who have pushers need to get push rules run because
|
||||
# that's how their pushers work
|
||||
if_users_with_pushers = yield store.get_if_users_have_pushers(
|
||||
local_users_in_room
|
||||
)
|
||||
user_ids = set(
|
||||
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
|
||||
)
|
||||
|
||||
users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
|
||||
|
||||
# any users with pushers must be ours: they have pushers
|
||||
user_ids = set(users_with_pushers)
|
||||
for r in receipts:
|
||||
if hs.is_mine_id(r['user_id']) and r['user_id'] in all_in_room:
|
||||
user_ids.add(r['user_id'])
|
||||
for uid in users_with_receipts:
|
||||
if uid in local_users_in_room:
|
||||
user_ids.add(uid)
|
||||
|
||||
# if this event is an invite event, we may need to run rules for the user
|
||||
# who's been invited, otherwise they won't get told they've been invited
|
||||
|
@ -102,8 +74,6 @@ def evaluator_for_event(event, hs, store):
|
|||
if has_pusher:
|
||||
user_ids.add(invited_user)
|
||||
|
||||
user_ids = list(user_ids)
|
||||
|
||||
rules_by_user = yield _get_rules(room_id, user_ids, store)
|
||||
|
||||
defer.returnValue(BulkPushRuleEvaluator(
|
||||
|
@ -141,7 +111,10 @@ class BulkPushRuleEvaluator:
|
|||
self.store, user_tuples, [event], {event.event_id: current_state}
|
||||
)
|
||||
|
||||
room_members = yield self.store.get_users_in_room(self.room_id)
|
||||
room_members = set(
|
||||
e.state_key for e in current_state.values()
|
||||
if e.type == EventTypes.Member and e.membership == Membership.JOIN
|
||||
)
|
||||
|
||||
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
|
||||
|
||||
|
|
|
@ -13,29 +13,19 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.push.baserules import list_with_base_rules
|
||||
|
||||
from synapse.push.rulekinds import (
|
||||
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
|
||||
)
|
||||
|
||||
import copy
|
||||
import simplejson as json
|
||||
|
||||
|
||||
def format_push_rules_for_user(user, rawrules, enabled_map):
|
||||
def format_push_rules_for_user(user, ruleslist):
|
||||
"""Converts a list of rawrules and a enabled map into nested dictionaries
|
||||
to match the Matrix client-server format for push rules"""
|
||||
|
||||
ruleslist = []
|
||||
for rawrule in rawrules:
|
||||
rule = dict(rawrule)
|
||||
rule["conditions"] = json.loads(rawrule["conditions"])
|
||||
rule["actions"] = json.loads(rawrule["actions"])
|
||||
ruleslist.append(rule)
|
||||
|
||||
# We're going to be mutating this a lot, so do a deep copy
|
||||
ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
|
||||
ruleslist = copy.deepcopy(ruleslist)
|
||||
|
||||
rules = {'global': {}, 'device': {}}
|
||||
|
||||
|
@ -60,9 +50,7 @@ def format_push_rules_for_user(user, rawrules, enabled_map):
|
|||
|
||||
template_rule = _rule_to_template(r)
|
||||
if template_rule:
|
||||
if r['rule_id'] in enabled_map:
|
||||
template_rule['enabled'] = enabled_map[r['rule_id']]
|
||||
elif 'enabled' in r:
|
||||
if 'enabled' in r:
|
||||
template_rule['enabled'] = r['enabled']
|
||||
else:
|
||||
template_rule['enabled'] = True
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||
|
||||
import logging
|
||||
|
||||
|
@ -32,12 +33,20 @@ DELAY_BEFORE_MAIL_MS = 10 * 60 * 1000
|
|||
# Each room maintains its own throttle counter, but each new mail notification
|
||||
# sends the pending notifications for all rooms.
|
||||
THROTTLE_START_MS = 10 * 60 * 1000
|
||||
THROTTLE_MAX_MS = 24 * 60 * 60 * 1000 # (2 * 60 * 1000) * (2 ** 11) # ~3 days
|
||||
THROTTLE_MULTIPLIER = 6 # 10 mins, 1 hour, 6 hours, 24 hours
|
||||
THROTTLE_MAX_MS = 24 * 60 * 60 * 1000 # 24h
|
||||
# THROTTLE_MULTIPLIER = 6 # 10 mins, 1 hour, 6 hours, 24 hours
|
||||
THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day
|
||||
|
||||
# If no event triggers a notification for this long after the previous,
|
||||
# the throttle is released.
|
||||
THROTTLE_RESET_AFTER_MS = (2 * 60 * 1000) * (2 ** 11) # ~3 days
|
||||
# 12 hours - a gap of 12 hours in conversation is surely enough to merit a new
|
||||
# notification when things get going again...
|
||||
THROTTLE_RESET_AFTER_MS = (12 * 60 * 60 * 1000)
|
||||
|
||||
# does each email include all unread notifs, or just the ones which have happened
|
||||
# since the last mail?
|
||||
# XXX: this is currently broken as it includes ones from parted rooms(!)
|
||||
INCLUDE_ALL_UNREAD_NOTIFS = False
|
||||
|
||||
|
||||
class EmailPusher(object):
|
||||
|
@ -65,7 +74,12 @@ class EmailPusher(object):
|
|||
self.processing = False
|
||||
|
||||
if self.hs.config.email_enable_notifs:
|
||||
self.mailer = Mailer(self.hs)
|
||||
if 'data' in pusherdict and 'brand' in pusherdict['data']:
|
||||
app_name = pusherdict['data']['brand']
|
||||
else:
|
||||
app_name = self.hs.config.email_app_name
|
||||
|
||||
self.mailer = Mailer(self.hs, app_name)
|
||||
else:
|
||||
self.mailer = None
|
||||
|
||||
|
@ -79,7 +93,11 @@ class EmailPusher(object):
|
|||
|
||||
def on_stop(self):
|
||||
if self.timed_call:
|
||||
try:
|
||||
self.timed_call.cancel()
|
||||
except (AlreadyCalled, AlreadyCancelled):
|
||||
pass
|
||||
self.timed_call = None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
|
||||
|
@ -126,9 +144,9 @@ class EmailPusher(object):
|
|||
up logging, measures and guards against multiple instances of it
|
||||
being run.
|
||||
"""
|
||||
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
|
||||
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
||||
)
|
||||
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
|
||||
fn = self.store.get_unread_push_actions_for_user_in_range_for_email
|
||||
unprocessed = yield fn(self.user_id, start, self.max_stream_ordering)
|
||||
|
||||
soonest_due_at = None
|
||||
|
||||
|
@ -150,7 +168,6 @@ class EmailPusher(object):
|
|||
# we then consider all previously outstanding notifications
|
||||
# to be delivered.
|
||||
|
||||
# debugging:
|
||||
reason = {
|
||||
'room_id': push_action['room_id'],
|
||||
'now': self.clock.time_msec(),
|
||||
|
@ -165,8 +182,11 @@ class EmailPusher(object):
|
|||
yield self.save_last_stream_ordering_and_success(max([
|
||||
ea['stream_ordering'] for ea in unprocessed
|
||||
]))
|
||||
|
||||
# we update the throttle on all the possible unprocessed push actions
|
||||
for ea in unprocessed:
|
||||
yield self.sent_notif_update_throttle(
|
||||
push_action['room_id'], push_action
|
||||
ea['room_id'], ea
|
||||
)
|
||||
break
|
||||
else:
|
||||
|
@ -174,7 +194,10 @@ class EmailPusher(object):
|
|||
soonest_due_at = should_notify_at
|
||||
|
||||
if self.timed_call is not None:
|
||||
try:
|
||||
self.timed_call.cancel()
|
||||
except (AlreadyCalled, AlreadyCancelled):
|
||||
pass
|
||||
self.timed_call = None
|
||||
|
||||
if soonest_due_at is not None:
|
||||
|
@ -263,5 +286,5 @@ class EmailPusher(object):
|
|||
logger.info("Sending notif email for user %r", self.user_id)
|
||||
|
||||
yield self.mailer.send_notification_mail(
|
||||
self.user_id, self.email, push_actions, reason
|
||||
self.app_id, self.user_id, self.email, push_actions, reason
|
||||
)
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
from synapse.push import PusherConfigException
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||
|
||||
import logging
|
||||
import push_rule_evaluator
|
||||
|
@ -38,6 +39,7 @@ class HttpPusher(object):
|
|||
self.hs = hs
|
||||
self.store = self.hs.get_datastore()
|
||||
self.clock = self.hs.get_clock()
|
||||
self.state_handler = self.hs.get_state_handler()
|
||||
self.user_id = pusherdict['user_name']
|
||||
self.app_id = pusherdict['app_id']
|
||||
self.app_display_name = pusherdict['app_display_name']
|
||||
|
@ -108,7 +110,11 @@ class HttpPusher(object):
|
|||
|
||||
def on_stop(self):
|
||||
if self.timed_call:
|
||||
try:
|
||||
self.timed_call.cancel()
|
||||
except (AlreadyCalled, AlreadyCancelled):
|
||||
pass
|
||||
self.timed_call = None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _process(self):
|
||||
|
@ -140,7 +146,8 @@ class HttpPusher(object):
|
|||
run once per pusher.
|
||||
"""
|
||||
|
||||
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
|
||||
fn = self.store.get_unread_push_actions_for_user_in_range_for_http
|
||||
unprocessed = yield fn(
|
||||
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
||||
)
|
||||
|
||||
|
@ -237,7 +244,9 @@ class HttpPusher(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _build_notification_dict(self, event, tweaks, badge):
|
||||
ctx = yield push_tools.get_context_for_event(self.hs.get_datastore(), event)
|
||||
ctx = yield push_tools.get_context_for_event(
|
||||
self.state_handler, event, self.user_id
|
||||
)
|
||||
|
||||
d = {
|
||||
'notification': {
|
||||
|
@ -269,8 +278,8 @@ class HttpPusher(object):
|
|||
if 'content' in event:
|
||||
d['notification']['content'] = event.content
|
||||
|
||||
if len(ctx['aliases']):
|
||||
d['notification']['room_alias'] = ctx['aliases'][0]
|
||||
# We no longer send aliases separately, instead, we send the human
|
||||
# readable name of the room, which may be an alias.
|
||||
if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0:
|
||||
d['notification']['sender_display_name'] = ctx['sender_display_name']
|
||||
if 'name' in ctx and len(ctx['name']) > 0:
|
||||
|
|
|
@ -41,11 +41,14 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \
|
||||
"in the %s room..."
|
||||
"in the %(room)s room..."
|
||||
MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..."
|
||||
MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..."
|
||||
MESSAGES_IN_ROOM = "There are some messages on %(app)s for you in the %(room)s room..."
|
||||
MESSAGES_IN_ROOMS = "Here are some messages on %(app)s you may have missed..."
|
||||
MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..."
|
||||
MESSAGES_IN_ROOM_AND_OTHERS = \
|
||||
"You have messages on %(app)s in the %(room)s room and others..."
|
||||
MESSAGES_FROM_PERSON_AND_OTHERS = \
|
||||
"You have messages on %(app)s from %(person)s and others..."
|
||||
INVITE_FROM_PERSON_TO_ROOM = "%(person)s has invited you to join the " \
|
||||
"%(room)s room on %(app)s..."
|
||||
INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..."
|
||||
|
@ -75,12 +78,14 @@ ALLOWED_ATTRS = {
|
|||
|
||||
|
||||
class Mailer(object):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs, app_name):
|
||||
self.hs = hs
|
||||
self.store = self.hs.get_datastore()
|
||||
self.auth_handler = self.hs.get_auth_handler()
|
||||
self.state_handler = self.hs.get_state_handler()
|
||||
loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
|
||||
self.app_name = self.hs.config.email_app_name
|
||||
self.app_name = app_name
|
||||
logger.info("Created Mailer for app_name %s" % app_name)
|
||||
env = jinja2.Environment(loader=loader)
|
||||
env.filters["format_ts"] = format_ts_filter
|
||||
env.filters["mxc_to_http"] = self.mxc_to_http_filter
|
||||
|
@ -92,8 +97,16 @@ class Mailer(object):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_notification_mail(self, user_id, email_address, push_actions, reason):
|
||||
raw_from = email.utils.parseaddr(self.hs.config.email_notif_from)[1]
|
||||
def send_notification_mail(self, app_id, user_id, email_address,
|
||||
push_actions, reason):
|
||||
try:
|
||||
from_string = self.hs.config.email_notif_from % {
|
||||
"app": self.app_name
|
||||
}
|
||||
except TypeError:
|
||||
from_string = self.hs.config.email_notif_from
|
||||
|
||||
raw_from = email.utils.parseaddr(from_string)[1]
|
||||
raw_to = email.utils.parseaddr(email_address)[1]
|
||||
|
||||
if raw_to == '':
|
||||
|
@ -119,6 +132,8 @@ class Mailer(object):
|
|||
user_display_name = yield self.store.get_profile_displayname(
|
||||
UserID.from_string(user_id).localpart
|
||||
)
|
||||
if user_display_name is None:
|
||||
user_display_name = user_id
|
||||
except StoreError:
|
||||
user_display_name = user_id
|
||||
|
||||
|
@ -128,9 +143,14 @@ class Mailer(object):
|
|||
state_by_room[room_id] = room_state
|
||||
|
||||
# Run at most 3 of these at once: sync does 10 at a time but email
|
||||
# notifs are much realtime than sync so we can afford to wait a bit.
|
||||
# notifs are much less realtime than sync so we can afford to wait a bit.
|
||||
yield concurrently_execute(_fetch_room_state, rooms_in_order, 3)
|
||||
|
||||
# actually sort our so-called rooms_in_order list, most recent room first
|
||||
rooms_in_order.sort(
|
||||
key=lambda r: -(notifs_by_room[r][-1]['received_ts'] or 0)
|
||||
)
|
||||
|
||||
rooms = []
|
||||
|
||||
for r in rooms_in_order:
|
||||
|
@ -139,17 +159,19 @@ class Mailer(object):
|
|||
)
|
||||
rooms.append(roomvars)
|
||||
|
||||
summary_text = self.make_summary_text(
|
||||
notifs_by_room, state_by_room, notif_events, user_id
|
||||
reason['room_name'] = calculate_room_name(
|
||||
state_by_room[reason['room_id']], user_id, fallback_to_members=True
|
||||
)
|
||||
|
||||
reason['room_name'] = calculate_room_name(
|
||||
state_by_room[reason['room_id']], user_id, fallback_to_members=False
|
||||
summary_text = self.make_summary_text(
|
||||
notifs_by_room, state_by_room, notif_events, user_id, reason
|
||||
)
|
||||
|
||||
template_vars = {
|
||||
"user_display_name": user_display_name,
|
||||
"unsubscribe_link": self.make_unsubscribe_link(),
|
||||
"unsubscribe_link": self.make_unsubscribe_link(
|
||||
user_id, app_id, email_address
|
||||
),
|
||||
"summary_text": summary_text,
|
||||
"app_name": self.app_name,
|
||||
"rooms": rooms,
|
||||
|
@ -164,7 +186,7 @@ class Mailer(object):
|
|||
|
||||
multipart_msg = MIMEMultipart('alternative')
|
||||
multipart_msg['Subject'] = "[%s] %s" % (self.app_name, summary_text)
|
||||
multipart_msg['From'] = self.hs.config.email_notif_from
|
||||
multipart_msg['From'] = from_string
|
||||
multipart_msg['To'] = email_address
|
||||
multipart_msg['Date'] = email.utils.formatdate()
|
||||
multipart_msg['Message-ID'] = email.utils.make_msgid()
|
||||
|
@ -251,14 +273,16 @@ class Mailer(object):
|
|||
|
||||
sender_state_event = room_state[("m.room.member", event.sender)]
|
||||
sender_name = name_from_member_event(sender_state_event)
|
||||
sender_avatar_url = sender_state_event.content["avatar_url"]
|
||||
sender_avatar_url = sender_state_event.content.get("avatar_url")
|
||||
|
||||
# 'hash' for deterministically picking default images: use
|
||||
# sender_hash % the number of default images to choose from
|
||||
sender_hash = string_ordinal_total(event.sender)
|
||||
|
||||
msgtype = event.content.get("msgtype")
|
||||
|
||||
ret = {
|
||||
"msgtype": event.content["msgtype"],
|
||||
"msgtype": msgtype,
|
||||
"is_historical": event.event_id != notif['event_id'],
|
||||
"id": event.event_id,
|
||||
"ts": event.origin_server_ts,
|
||||
|
@ -267,9 +291,9 @@ class Mailer(object):
|
|||
"sender_hash": sender_hash,
|
||||
}
|
||||
|
||||
if event.content["msgtype"] == "m.text":
|
||||
if msgtype == "m.text":
|
||||
self.add_text_message_vars(ret, event)
|
||||
elif event.content["msgtype"] == "m.image":
|
||||
elif msgtype == "m.image":
|
||||
self.add_image_message_vars(ret, event)
|
||||
|
||||
if "body" in event.content:
|
||||
|
@ -278,16 +302,17 @@ class Mailer(object):
|
|||
return ret
|
||||
|
||||
def add_text_message_vars(self, messagevars, event):
|
||||
if "format" in event.content:
|
||||
msgformat = event.content["format"]
|
||||
else:
|
||||
msgformat = None
|
||||
msgformat = event.content.get("format")
|
||||
|
||||
messagevars["format"] = msgformat
|
||||
|
||||
if msgformat == "org.matrix.custom.html":
|
||||
messagevars["body_text_html"] = safe_markup(event.content["formatted_body"])
|
||||
else:
|
||||
messagevars["body_text_html"] = safe_text(event.content["body"])
|
||||
formatted_body = event.content.get("formatted_body")
|
||||
body = event.content.get("body")
|
||||
|
||||
if msgformat == "org.matrix.custom.html" and formatted_body:
|
||||
messagevars["body_text_html"] = safe_markup(formatted_body)
|
||||
elif body:
|
||||
messagevars["body_text_html"] = safe_text(body)
|
||||
|
||||
return messagevars
|
||||
|
||||
|
@ -296,7 +321,8 @@ class Mailer(object):
|
|||
|
||||
return messagevars
|
||||
|
||||
def make_summary_text(self, notifs_by_room, state_by_room, notif_events, user_id):
|
||||
def make_summary_text(self, notifs_by_room, state_by_room,
|
||||
notif_events, user_id, reason):
|
||||
if len(notifs_by_room) == 1:
|
||||
# Only one room has new stuff
|
||||
room_id = notifs_by_room.keys()[0]
|
||||
|
@ -371,7 +397,26 @@ class Mailer(object):
|
|||
}
|
||||
else:
|
||||
# Stuff's happened in multiple different rooms
|
||||
return MESSAGES_IN_ROOMS % {
|
||||
|
||||
# ...but we still refer to the 'reason' room which triggered the mail
|
||||
if reason['room_name'] is not None:
|
||||
return MESSAGES_IN_ROOM_AND_OTHERS % {
|
||||
"room": reason['room_name'],
|
||||
"app": self.app_name,
|
||||
}
|
||||
else:
|
||||
# If the reason room doesn't have a name, say who the messages
|
||||
# are from explicitly to avoid, "messages in the Bob room"
|
||||
sender_ids = list(set([
|
||||
notif_events[n['event_id']].sender
|
||||
for n in notifs_by_room[reason['room_id']]
|
||||
]))
|
||||
|
||||
return MESSAGES_FROM_PERSON_AND_OTHERS % {
|
||||
"person": descriptor_from_member_events([
|
||||
state_by_room[reason['room_id']][("m.room.member", s)]
|
||||
for s in sender_ids
|
||||
]),
|
||||
"app": self.app_name,
|
||||
}
|
||||
|
||||
|
@ -393,9 +438,18 @@ class Mailer(object):
|
|||
notif['room_id'], notif['event_id']
|
||||
)
|
||||
|
||||
def make_unsubscribe_link(self):
|
||||
# XXX: matrix.to
|
||||
return "https://vector.im/#/settings"
|
||||
def make_unsubscribe_link(self, user_id, app_id, email_address):
|
||||
params = {
|
||||
"access_token": self.auth_handler.generate_delete_pusher_token(user_id),
|
||||
"app_id": app_id,
|
||||
"pushkey": email_address,
|
||||
}
|
||||
|
||||
# XXX: make r0 once API is stable
|
||||
return "%s_matrix/client/unstable/pushers/remove?%s" % (
|
||||
self.hs.config.public_baseurl,
|
||||
urllib.urlencode(params),
|
||||
)
|
||||
|
||||
def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
|
||||
if value[0:6] != "mxc://":
|
||||
|
|
|
@ -14,6 +14,9 @@
|
|||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from synapse.util.presentable_names import (
|
||||
calculate_room_name, name_from_member_event
|
||||
)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -45,24 +48,21 @@ def get_badge_count(store, user_id):
|
|||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_context_for_event(store, ev):
|
||||
name_aliases = yield store.get_room_name_and_aliases(
|
||||
ev.room_id
|
||||
)
|
||||
def get_context_for_event(state_handler, ev, user_id):
|
||||
ctx = {}
|
||||
|
||||
ctx = {'aliases': name_aliases[1]}
|
||||
if name_aliases[0] is not None:
|
||||
ctx['name'] = name_aliases[0]
|
||||
room_state = yield state_handler.get_current_state(ev.room_id)
|
||||
|
||||
their_member_events_for_room = yield store.get_current_state(
|
||||
room_id=ev.room_id,
|
||||
event_type='m.room.member',
|
||||
state_key=ev.user_id
|
||||
# we no longer bother setting room_alias, and make room_name the
|
||||
# human-readable name instead, be that m.room.name, an alias or
|
||||
# a list of people in the room
|
||||
name = calculate_room_name(
|
||||
room_state, user_id, fallback_to_single_member=False
|
||||
)
|
||||
for mev in their_member_events_for_room:
|
||||
if mev.content['membership'] == 'join' and 'displayname' in mev.content:
|
||||
dn = mev.content['displayname']
|
||||
if dn is not None:
|
||||
ctx['sender_display_name'] = dn
|
||||
if name:
|
||||
ctx['name'] = name
|
||||
|
||||
sender_state_event = room_state[("m.room.member", ev.sender)]
|
||||
ctx['sender_display_name'] = name_from_member_event(sender_state_event)
|
||||
|
||||
defer.returnValue(ctx)
|
||||
|
|
|
@ -48,6 +48,12 @@ CONDITIONAL_REQUIREMENTS = {
|
|||
"Jinja2>=2.8": ["Jinja2>=2.8"],
|
||||
"bleach>=1.4.2": ["bleach>=1.4.2"],
|
||||
},
|
||||
"ldap": {
|
||||
"ldap3>=1.0": ["ldap3>=1.0"],
|
||||
},
|
||||
"psutil": {
|
||||
"psutil>=2.0.0": ["psutil>=2.0.0"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
59
synapse/replication/presence_resource.py
Normal file
59
synapse/replication/presence_resource.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.http.server import respond_with_json_bytes, request_handler
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
|
||||
|
||||
class PresenceResource(Resource):
|
||||
"""
|
||||
HTTP endpoint for marking users as syncing.
|
||||
|
||||
POST /_synapse/replication/presence HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"process_id": "<process_id>",
|
||||
"syncing_users": ["<user_id>"]
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
Resource.__init__(self) # Resource is old-style, so no super()
|
||||
|
||||
self.version_string = hs.version_string
|
||||
self.clock = hs.get_clock()
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
|
||||
def render_POST(self, request):
|
||||
self._async_render_POST(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@request_handler()
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_POST(self, request):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
process_id = content["process_id"]
|
||||
syncing_user_ids = content["syncing_users"]
|
||||
|
||||
yield self.presence_handler.update_external_syncs(
|
||||
process_id, set(syncing_user_ids)
|
||||
)
|
||||
|
||||
respond_with_json_bytes(request, 200, "{}")
|
|
@ -16,6 +16,7 @@
|
|||
from synapse.http.servlet import parse_integer, parse_string
|
||||
from synapse.http.server import request_handler, finish_request
|
||||
from synapse.replication.pusher_resource import PusherResource
|
||||
from synapse.replication.presence_resource import PresenceResource
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
|
@ -115,6 +116,7 @@ class ReplicationResource(Resource):
|
|||
self.clock = hs.get_clock()
|
||||
|
||||
self.putChild("remove_pushers", PusherResource(hs))
|
||||
self.putChild("syncing_users", PresenceResource(hs))
|
||||
|
||||
def render_GET(self, request):
|
||||
self._async_render_GET(request)
|
||||
|
|
|
@ -15,7 +15,10 @@
|
|||
|
||||
from ._base import BaseSlavedStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.account_data import AccountDataStore
|
||||
from synapse.storage.tags import TagsStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
|
||||
class SlavedAccountDataStore(BaseSlavedStore):
|
||||
|
@ -25,6 +28,14 @@ class SlavedAccountDataStore(BaseSlavedStore):
|
|||
self._account_data_id_gen = SlavedIdTracker(
|
||||
db_conn, "account_data_max_stream_id", "stream_id",
|
||||
)
|
||||
self._account_data_stream_cache = StreamChangeCache(
|
||||
"AccountDataAndTagsChangeCache",
|
||||
self._account_data_id_gen.get_current_token(),
|
||||
)
|
||||
|
||||
get_account_data_for_user = (
|
||||
AccountDataStore.__dict__["get_account_data_for_user"]
|
||||
)
|
||||
|
||||
get_global_account_data_by_type_for_users = (
|
||||
AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
|
||||
|
@ -34,6 +45,16 @@ class SlavedAccountDataStore(BaseSlavedStore):
|
|||
AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
|
||||
)
|
||||
|
||||
get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
|
||||
|
||||
get_updated_tags = DataStore.get_updated_tags.__func__
|
||||
get_updated_account_data_for_user = (
|
||||
DataStore.get_updated_account_data_for_user.__func__
|
||||
)
|
||||
|
||||
def get_max_account_data_stream_id(self):
|
||||
return self._account_data_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedAccountDataStore, self).stream_positions()
|
||||
position = self._account_data_id_gen.get_current_token()
|
||||
|
@ -47,15 +68,33 @@ class SlavedAccountDataStore(BaseSlavedStore):
|
|||
if stream:
|
||||
self._account_data_id_gen.advance(int(stream["position"]))
|
||||
for row in stream["rows"]:
|
||||
user_id, data_type = row[1:3]
|
||||
position, user_id, data_type = row[:3]
|
||||
self.get_global_account_data_by_type_for_user.invalidate(
|
||||
(data_type, user_id,)
|
||||
)
|
||||
self.get_account_data_for_user.invalidate((user_id,))
|
||||
self._account_data_stream_cache.entity_has_changed(
|
||||
user_id, position
|
||||
)
|
||||
|
||||
stream = result.get("room_account_data")
|
||||
if stream:
|
||||
self._account_data_id_gen.advance(int(stream["position"]))
|
||||
for row in stream["rows"]:
|
||||
position, user_id = row[:2]
|
||||
self.get_account_data_for_user.invalidate((user_id,))
|
||||
self._account_data_stream_cache.entity_has_changed(
|
||||
user_id, position
|
||||
)
|
||||
|
||||
stream = result.get("tag_account_data")
|
||||
if stream:
|
||||
self._account_data_id_gen.advance(int(stream["position"]))
|
||||
for row in stream["rows"]:
|
||||
position, user_id = row[:2]
|
||||
self.get_tags_for_user.invalidate((user_id,))
|
||||
self._account_data_stream_cache.entity_has_changed(
|
||||
user_id, position
|
||||
)
|
||||
|
||||
return super(SlavedAccountDataStore, self).process_replication(result)
|
||||
|
|
30
synapse/replication/slave/storage/appservice.py
Normal file
30
synapse/replication/slave/storage/appservice.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from synapse.storage import DataStore
|
||||
from synapse.config.appservice import load_appservices
|
||||
|
||||
|
||||
class SlavedApplicationServiceStore(BaseSlavedStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedApplicationServiceStore, self).__init__(db_conn, hs)
|
||||
self.services_cache = load_appservices(
|
||||
hs.config.server_name,
|
||||
hs.config.app_service_config_files
|
||||
)
|
||||
|
||||
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
|
||||
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
|
23
synapse/replication/slave/storage/directory.py
Normal file
23
synapse/replication/slave/storage/directory.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from synapse.storage.directory import DirectoryStore
|
||||
|
||||
|
||||
class DirectoryStore(BaseSlavedStore):
|
||||
get_aliases_for_room = DirectoryStore.__dict__[
|
||||
"get_aliases_for_room"
|
||||
].orig
|
|
@ -18,11 +18,11 @@ from ._slaved_id_tracker import SlavedIdTracker
|
|||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.room import RoomStore
|
||||
from synapse.storage.roommember import RoomMemberStore
|
||||
from synapse.storage.event_federation import EventFederationStore
|
||||
from synapse.storage.event_push_actions import EventPushActionsStore
|
||||
from synapse.storage.state import StateStore
|
||||
from synapse.storage.stream import StreamStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
import ujson as json
|
||||
|
@ -57,10 +57,12 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
"EventsRoomStreamChangeCache", min_event_val,
|
||||
prefilled_cache=event_cache_prefill,
|
||||
)
|
||||
self._membership_stream_cache = StreamChangeCache(
|
||||
"MembershipStreamChangeCache", events_max,
|
||||
)
|
||||
|
||||
# Cached functions can't be accessed through a class instance so we need
|
||||
# to reach inside the __dict__ to extract them.
|
||||
get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"]
|
||||
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
|
||||
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
|
||||
get_latest_event_ids_in_room = EventFederationStore.__dict__[
|
||||
|
@ -87,9 +89,15 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
_get_state_group_from_group = (
|
||||
StateStore.__dict__["_get_state_group_from_group"]
|
||||
)
|
||||
get_recent_event_ids_for_room = (
|
||||
StreamStore.__dict__["get_recent_event_ids_for_room"]
|
||||
)
|
||||
|
||||
get_unread_push_actions_for_user_in_range = (
|
||||
DataStore.get_unread_push_actions_for_user_in_range.__func__
|
||||
get_unread_push_actions_for_user_in_range_for_http = (
|
||||
DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
|
||||
)
|
||||
get_unread_push_actions_for_user_in_range_for_email = (
|
||||
DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
|
||||
)
|
||||
get_push_action_users_in_range = (
|
||||
DataStore.get_push_action_users_in_range.__func__
|
||||
|
@ -109,24 +117,25 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
DataStore.get_room_events_stream_for_room.__func__
|
||||
)
|
||||
get_events_around = DataStore.get_events_around.__func__
|
||||
get_state_for_event = DataStore.get_state_for_event.__func__
|
||||
get_state_for_events = DataStore.get_state_for_events.__func__
|
||||
get_state_groups = DataStore.get_state_groups.__func__
|
||||
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
|
||||
get_room_events_stream_for_rooms = (
|
||||
DataStore.get_room_events_stream_for_rooms.__func__
|
||||
)
|
||||
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
|
||||
|
||||
_set_before_and_after = DataStore._set_before_and_after
|
||||
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
|
||||
|
||||
_get_events = DataStore._get_events.__func__
|
||||
_get_events_from_cache = DataStore._get_events_from_cache.__func__
|
||||
|
||||
_invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
|
||||
_parse_events_txn = DataStore._parse_events_txn.__func__
|
||||
_get_events_txn = DataStore._get_events_txn.__func__
|
||||
_get_event_txn = DataStore._get_event_txn.__func__
|
||||
_enqueue_events = DataStore._enqueue_events.__func__
|
||||
_do_fetch = DataStore._do_fetch.__func__
|
||||
_fetch_events_txn = DataStore._fetch_events_txn.__func__
|
||||
_fetch_event_rows = DataStore._fetch_event_rows.__func__
|
||||
_get_event_from_row = DataStore._get_event_from_row.__func__
|
||||
_get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__
|
||||
_get_rooms_for_user_where_membership_is_txn = (
|
||||
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
|
||||
)
|
||||
|
@ -136,6 +145,15 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
_get_events_around_txn = DataStore._get_events_around_txn.__func__
|
||||
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
|
||||
|
||||
get_backfill_events = DataStore.get_backfill_events.__func__
|
||||
_get_backfill_events = DataStore._get_backfill_events.__func__
|
||||
get_missing_events = DataStore.get_missing_events.__func__
|
||||
_get_missing_events = DataStore._get_missing_events.__func__
|
||||
|
||||
get_auth_chain = DataStore.get_auth_chain.__func__
|
||||
get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
|
||||
_get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedEventStore, self).stream_positions()
|
||||
result["events"] = self._stream_id_gen.get_current_token()
|
||||
|
@ -194,7 +212,6 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
self.get_rooms_for_user.invalidate_all()
|
||||
self.get_users_in_room.invalidate((event.room_id,))
|
||||
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
|
||||
self.get_room_name_and_aliases.invalidate((event.room_id,))
|
||||
|
||||
self._invalidate_get_event_cache(event.event_id)
|
||||
|
||||
|
@ -220,9 +237,9 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
self.get_rooms_for_user.invalidate((event.state_key,))
|
||||
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
|
||||
self.get_users_in_room.invalidate((event.room_id,))
|
||||
# self._membership_stream_cache.entity_has_changed(
|
||||
# event.state_key, event.internal_metadata.stream_ordering
|
||||
# )
|
||||
self._membership_stream_cache.entity_has_changed(
|
||||
event.state_key, event.internal_metadata.stream_ordering
|
||||
)
|
||||
self.get_invited_rooms_for_user.invalidate((event.state_key,))
|
||||
|
||||
if not event.is_state():
|
||||
|
@ -238,9 +255,3 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
self._get_current_state_for_key.invalidate((
|
||||
event.room_id, event.type, event.state_key
|
||||
))
|
||||
|
||||
if event.type in [EventTypes.Name, EventTypes.Aliases]:
|
||||
self.get_room_name_and_aliases.invalidate(
|
||||
(event.room_id,)
|
||||
)
|
||||
pass
|
||||
|
|
25
synapse/replication/slave/storage/filtering.py
Normal file
25
synapse/replication/slave/storage/filtering.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from synapse.storage.filtering import FilteringStore
|
||||
|
||||
|
||||
class SlavedFilteringStore(BaseSlavedStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedFilteringStore, self).__init__(db_conn, hs)
|
||||
|
||||
# Filters are immutable so this cache doesn't need to be expired
|
||||
get_user_filter = FilteringStore.__dict__["get_user_filter"]
|
33
synapse/replication/slave/storage/keys.py
Normal file
33
synapse/replication/slave/storage/keys.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.keys import KeyStore
|
||||
|
||||
|
||||
class SlavedKeyStore(BaseSlavedStore):
|
||||
_get_server_verify_key = KeyStore.__dict__[
|
||||
"_get_server_verify_key"
|
||||
]
|
||||
|
||||
get_server_verify_keys = DataStore.get_server_verify_keys.__func__
|
||||
store_server_verify_key = DataStore.store_server_verify_key.__func__
|
||||
|
||||
get_server_certificate = DataStore.get_server_certificate.__func__
|
||||
store_server_certificate = DataStore.store_server_certificate.__func__
|
||||
|
||||
get_server_keys_json = DataStore.get_server_keys_json.__func__
|
||||
store_server_keys_json = DataStore.store_server_keys_json.__func__
|
59
synapse/replication/slave/storage/presence.py
Normal file
59
synapse/replication/slave/storage/presence.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
from synapse.storage import DataStore
|
||||
|
||||
|
||||
class SlavedPresenceStore(BaseSlavedStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedPresenceStore, self).__init__(db_conn, hs)
|
||||
self._presence_id_gen = SlavedIdTracker(
|
||||
db_conn, "presence_stream", "stream_id",
|
||||
)
|
||||
|
||||
self._presence_on_startup = self._get_active_presence(db_conn)
|
||||
|
||||
self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache(
|
||||
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
|
||||
)
|
||||
|
||||
_get_active_presence = DataStore._get_active_presence.__func__
|
||||
take_presence_startup_info = DataStore.take_presence_startup_info.__func__
|
||||
get_presence_for_users = DataStore.get_presence_for_users.__func__
|
||||
|
||||
def get_current_presence_token(self):
|
||||
return self._presence_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedPresenceStore, self).stream_positions()
|
||||
position = self._presence_id_gen.get_current_token()
|
||||
result["presence"] = position
|
||||
return result
|
||||
|
||||
def process_replication(self, result):
|
||||
stream = result.get("presence")
|
||||
if stream:
|
||||
self._presence_id_gen.advance(int(stream["position"]))
|
||||
for row in stream["rows"]:
|
||||
position, user_id = row[:2]
|
||||
self.presence_stream_cache.entity_has_changed(
|
||||
user_id, position
|
||||
)
|
||||
|
||||
return super(SlavedPresenceStore, self).process_replication(result)
|
67
synapse/replication/slave/storage/push_rule.py
Normal file
67
synapse/replication/slave/storage/push_rule.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .events import SlavedEventStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.push_rule import PushRuleStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
|
||||
class SlavedPushRuleStore(SlavedEventStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
|
||||
self._push_rules_stream_id_gen = SlavedIdTracker(
|
||||
db_conn, "push_rules_stream", "stream_id",
|
||||
)
|
||||
self.push_rules_stream_cache = StreamChangeCache(
|
||||
"PushRulesStreamChangeCache",
|
||||
self._push_rules_stream_id_gen.get_current_token(),
|
||||
)
|
||||
|
||||
get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
|
||||
get_push_rules_enabled_for_user = (
|
||||
PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
|
||||
)
|
||||
have_push_rules_changed_for_user = (
|
||||
DataStore.have_push_rules_changed_for_user.__func__
|
||||
)
|
||||
|
||||
def get_push_rules_stream_token(self):
|
||||
return (
|
||||
self._push_rules_stream_id_gen.get_current_token(),
|
||||
self._stream_id_gen.get_current_token(),
|
||||
)
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedPushRuleStore, self).stream_positions()
|
||||
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def process_replication(self, result):
|
||||
stream = result.get("push_rules")
|
||||
if stream:
|
||||
for row in stream["rows"]:
|
||||
position = row[0]
|
||||
user_id = row[2]
|
||||
self.get_push_rules_for_user.invalidate((user_id,))
|
||||
self.get_push_rules_enabled_for_user.invalidate((user_id,))
|
||||
self.push_rules_stream_cache.entity_has_changed(
|
||||
user_id, position
|
||||
)
|
||||
|
||||
self._push_rules_stream_id_gen.advance(int(stream["position"]))
|
||||
|
||||
return super(SlavedPushRuleStore, self).process_replication(result)
|
|
@ -18,6 +18,7 @@ from ._slaved_id_tracker import SlavedIdTracker
|
|||
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.receipts import ReceiptsStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
# So, um, we want to borrow a load of functions intended for reading from
|
||||
# a DataStore, but we don't want to take functions that either write to the
|
||||
|
@ -37,11 +38,28 @@ class SlavedReceiptsStore(BaseSlavedStore):
|
|||
db_conn, "receipts_linearized", "stream_id"
|
||||
)
|
||||
|
||||
self._receipts_stream_cache = StreamChangeCache(
|
||||
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
|
||||
)
|
||||
|
||||
get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
|
||||
get_linearized_receipts_for_room = (
|
||||
ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
|
||||
)
|
||||
_get_linearized_receipts_for_rooms = (
|
||||
ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
|
||||
)
|
||||
get_last_receipt_event_id_for_user = (
|
||||
ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
|
||||
)
|
||||
|
||||
get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
|
||||
get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
|
||||
|
||||
get_linearized_receipts_for_rooms = (
|
||||
DataStore.get_linearized_receipts_for_rooms.__func__
|
||||
)
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedReceiptsStore, self).stream_positions()
|
||||
result["receipts"] = self._receipts_id_gen.get_current_token()
|
||||
|
@ -52,10 +70,15 @@ class SlavedReceiptsStore(BaseSlavedStore):
|
|||
if stream:
|
||||
self._receipts_id_gen.advance(int(stream["position"]))
|
||||
for row in stream["rows"]:
|
||||
room_id, receipt_type, user_id = row[1:4]
|
||||
position, room_id, receipt_type, user_id = row[:4]
|
||||
self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
|
||||
self._receipts_stream_cache.entity_has_changed(room_id, position)
|
||||
|
||||
return super(SlavedReceiptsStore, self).process_replication(result)
|
||||
|
||||
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
|
||||
self.get_receipts_for_user.invalidate((user_id, receipt_type))
|
||||
self.get_linearized_receipts_for_room.invalidate_many((room_id,))
|
||||
self.get_last_receipt_event_id_for_user.invalidate(
|
||||
(user_id, room_id, receipt_type)
|
||||
)
|
||||
|
|
30
synapse/replication/slave/storage/registration.py
Normal file
30
synapse/replication/slave/storage/registration.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.registration import RegistrationStore
|
||||
|
||||
|
||||
class SlavedRegistrationStore(BaseSlavedStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedRegistrationStore, self).__init__(db_conn, hs)
|
||||
|
||||
# TODO: use the cached version and invalidate deleted tokens
|
||||
get_user_by_access_token = RegistrationStore.__dict__[
|
||||
"get_user_by_access_token"
|
||||
].orig
|
||||
|
||||
_query_for_auth = DataStore._query_for_auth.__func__
|
21
synapse/replication/slave/storage/room.py
Normal file
21
synapse/replication/slave/storage/room.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from synapse.storage import DataStore
|
||||
|
||||
|
||||
class RoomStore(BaseSlavedStore):
|
||||
get_public_room_ids = DataStore.get_public_room_ids.__func__
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue