Merge branch 'develop' of github.com:matrix-org/synapse into erikj/initial_sync_asnyc

This commit is contained in:
Erik Johnston 2019-12-11 17:01:41 +00:00
commit 6828b47c45
156 changed files with 1623 additions and 846 deletions

@@ -1,7 +1,7 @@
 # Configuration file used for testing the 'synapse_port_db' script.
 # Tells the script to connect to the postgresql database that will be available in the
 # CI's Docker setup at the point where this file is considered.
-server_name: "test"
+server_name: "localhost:8800"
 signing_key_path: "/src/.buildkite/test.signing.key"

@@ -1,7 +1,7 @@
 # Configuration file used for testing the 'synapse_port_db' script.
 # Tells the 'update_database' script to connect to the test SQLite database to upgrade its
 # schema and run background updates on it.
-server_name: "test"
+server_name: "localhost:8800"
 signing_key_path: "/src/.buildkite/test.signing.key"

@@ -1,3 +1,86 @@
Synapse 1.7.0rc2 (2019-12-11)
=============================

Bugfixes
--------

- Fix incorrect error message for invalid requests when setting user's avatar URL. ([\#6497](https://github.com/matrix-org/synapse/issues/6497))
- Fix support for SQLite 3.7. ([\#6499](https://github.com/matrix-org/synapse/issues/6499))
- Fix regression where sending email push would not work when using a pusher worker. ([\#6507](https://github.com/matrix-org/synapse/issues/6507), [\#6509](https://github.com/matrix-org/synapse/issues/6509))

Synapse 1.7.0rc1 (2019-12-09)
=============================

Features
--------

- Implement per-room message retention policies. ([\#5815](https://github.com/matrix-org/synapse/issues/5815), [\#6436](https://github.com/matrix-org/synapse/issues/6436))
- Add etag and count fields to key backup endpoints to help clients guess if there are new keys. ([\#5858](https://github.com/matrix-org/synapse/issues/5858))
- Add `/admin/v2/users` endpoint with pagination. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#5925](https://github.com/matrix-org/synapse/issues/5925))
- Require User-Interactive Authentication for `/account/3pid/add`, meaning the user's password will be required to add a third-party ID to their account. ([\#6119](https://github.com/matrix-org/synapse/issues/6119))
- Implement the `/_matrix/federation/unstable/net.atleastfornow/state/<context>` API as drafted in MSC2314. ([\#6176](https://github.com/matrix-org/synapse/issues/6176))
- Configure privacy-preserving settings by default for the room directory. ([\#6355](https://github.com/matrix-org/synapse/issues/6355))
- Add ephemeral messages support by partially implementing [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228). ([\#6409](https://github.com/matrix-org/synapse/issues/6409))
- Add support for [MSC 2367](https://github.com/matrix-org/matrix-doc/pull/2367), which allows specifying a reason on all membership events. ([\#6434](https://github.com/matrix-org/synapse/issues/6434))

Bugfixes
--------

- Transfer non-standard power levels on room upgrade. ([\#6237](https://github.com/matrix-org/synapse/issues/6237))
- Fix error from the Pillow library when uploading RGBA images. ([\#6241](https://github.com/matrix-org/synapse/issues/6241))
- Correctly apply the event filter to the `state`, `events_before` and `events_after` fields in the response to `/context` requests. ([\#6329](https://github.com/matrix-org/synapse/issues/6329))
- Fix caching devices for remote users when using workers, so that we don't attempt to refetch (and potentially fail) each time a user requests devices. ([\#6332](https://github.com/matrix-org/synapse/issues/6332))
- Prevent account data syncs getting lost across TCP replication. ([\#6333](https://github.com/matrix-org/synapse/issues/6333))
- Fix bug: TypeError in `register_user()` while using LDAP auth module. ([\#6406](https://github.com/matrix-org/synapse/issues/6406))
- Fix an intermittent exception when handling read-receipts. ([\#6408](https://github.com/matrix-org/synapse/issues/6408))
- Fix broken guest registration when there are existing blocks of numeric user IDs. ([\#6420](https://github.com/matrix-org/synapse/issues/6420))
- Fix startup error when http proxy is defined. ([\#6421](https://github.com/matrix-org/synapse/issues/6421))
- Fix error when using synapse_port_db on a vanilla synapse db. ([\#6449](https://github.com/matrix-org/synapse/issues/6449))
- Fix uploading multiple cross signing signatures for the same user. ([\#6451](https://github.com/matrix-org/synapse/issues/6451))
- Fix bug which led to exceptions being thrown in a loop when a cross-signed device is deleted. ([\#6462](https://github.com/matrix-org/synapse/issues/6462))
- Fix `synapse_port_db` not exiting with a 0 code if something went wrong during the port process. ([\#6470](https://github.com/matrix-org/synapse/issues/6470))
- Improve sanity-checking when receiving events over federation. ([\#6472](https://github.com/matrix-org/synapse/issues/6472))
- Fix inaccurate per-block Prometheus metrics. ([\#6491](https://github.com/matrix-org/synapse/issues/6491))
- Fix small performance regression for sending invites. ([\#6493](https://github.com/matrix-org/synapse/issues/6493))
- Back out cross-signing code added in Synapse 1.5.0, which caused a performance regression. ([\#6494](https://github.com/matrix-org/synapse/issues/6494))

Improved Documentation
----------------------

- Update documentation and variables in user contributed systemd reference file. ([\#6369](https://github.com/matrix-org/synapse/issues/6369), [\#6490](https://github.com/matrix-org/synapse/issues/6490))
- Fix link in the user directory documentation. ([\#6388](https://github.com/matrix-org/synapse/issues/6388))
- Add build instructions to the docker readme. ([\#6390](https://github.com/matrix-org/synapse/issues/6390))
- Switch Ubuntu package install recommendation to use python3 packages in INSTALL.md. ([\#6443](https://github.com/matrix-org/synapse/issues/6443))
- Write some docs for the quarantine_media api. ([\#6458](https://github.com/matrix-org/synapse/issues/6458))
- Convert CONTRIBUTING.rst to markdown (among other small fixes). ([\#6461](https://github.com/matrix-org/synapse/issues/6461))

Deprecations and Removals
-------------------------

- Remove admin/v1/users_paginate endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#5925](https://github.com/matrix-org/synapse/issues/5925))
- Remove fallback for federation with old servers which lack the /federation/v1/state_ids API. ([\#6488](https://github.com/matrix-org/synapse/issues/6488))

Internal Changes
----------------

- Add benchmarks for structured logging and improve output performance. ([\#6266](https://github.com/matrix-org/synapse/issues/6266))
- Improve the performance of outputting structured logging. ([\#6322](https://github.com/matrix-org/synapse/issues/6322))
- Refactor some code in the event authentication path for clarity. ([\#6343](https://github.com/matrix-org/synapse/issues/6343), [\#6468](https://github.com/matrix-org/synapse/issues/6468), [\#6480](https://github.com/matrix-org/synapse/issues/6480))
- Clean up some unnecessary quotation marks around the codebase. ([\#6362](https://github.com/matrix-org/synapse/issues/6362))
- Complain on startup instead of 500'ing during runtime when `public_baseurl` isn't set when necessary. ([\#6379](https://github.com/matrix-org/synapse/issues/6379))
- Add a test scenario to make sure room history purges don't break `/messages` in the future. ([\#6392](https://github.com/matrix-org/synapse/issues/6392))
- Clarifications for the email configuration settings. ([\#6423](https://github.com/matrix-org/synapse/issues/6423))
- Add more tests to the blacklist when running in worker mode. ([\#6429](https://github.com/matrix-org/synapse/issues/6429))
- Refactor data store layer to support multiple databases in the future. ([\#6454](https://github.com/matrix-org/synapse/issues/6454), [\#6464](https://github.com/matrix-org/synapse/issues/6464), [\#6469](https://github.com/matrix-org/synapse/issues/6469), [\#6487](https://github.com/matrix-org/synapse/issues/6487))
- Port synapse.rest.client.v1 to async/await. ([\#6482](https://github.com/matrix-org/synapse/issues/6482))
- Port synapse.rest.client.v2_alpha to async/await. ([\#6483](https://github.com/matrix-org/synapse/issues/6483))
- Port SyncHandler to async/await. ([\#6484](https://github.com/matrix-org/synapse/issues/6484))

Synapse 1.6.1 (2019-11-28)
==========================

@@ -1 +0,0 @@
-Implement per-room message retention policies.

@@ -1 +0,0 @@
-Add etag and count fields to key backup endpoints to help clients guess if there are new keys.

@@ -1 +0,0 @@
-Add admin/v2/users endpoint with pagination. Contributed by Awesome Technologies Innovationslabor GmbH.

@@ -1 +0,0 @@
-Remove admin/v1/users_paginate endpoint. Contributed by Awesome Technologies Innovationslabor GmbH.

@@ -1 +0,0 @@
-Require User-Interactive Authentication for `/account/3pid/add`, meaning the user's password will be required to add a third-party ID to their account.

@@ -1 +0,0 @@
-Implement the `/_matrix/federation/unstable/net.atleastfornow/state/<context>` API as drafted in MSC2314.

@@ -1 +0,0 @@
-Transfer non-standard power levels on room upgrade.

@@ -1 +0,0 @@
-Fix error from the Pillow library when uploading RGBA images.

@@ -1 +0,0 @@
-Add benchmarks for structured logging and improve output performance.

@@ -1 +0,0 @@
-Improve the performance of outputting structured logging.

@@ -1 +0,0 @@
-Correctly apply the event filter to the `state`, `events_before` and `events_after` fields in the response to `/context` requests.

@@ -1 +0,0 @@
-Fix caching devices for remote users when using workers, so that we don't attempt to refetch (and potentially fail) each time a user requests devices.

@@ -1 +0,0 @@
-Prevent account data syncs getting lost across TCP replication.

changelog.d/6349.feature (new file)
@@ -0,0 +1 @@
+Implement v2 APIs for the `send_join` and `send_leave` federation endpoints (as described in [MSC1802](https://github.com/matrix-org/matrix-doc/pull/1802)).

@@ -1 +0,0 @@
-Configure privacy preserving settings by default for the room directory.

@@ -1 +0,0 @@
-Clean up some unnecessary quotation marks around the codebase.

@@ -1 +0,0 @@
-Update documentation and variables in user contributed systemd reference file.

changelog.d/6377.bugfix (new file)
@@ -0,0 +1 @@
+Prevent redacted events from being returned during message search.

@@ -1 +0,0 @@
-Complain on startup instead of 500'ing during runtime when `public_baseurl` isn't set when necessary.

changelog.d/6385.bugfix (new file)
@@ -0,0 +1 @@
+Prevent error on trying to search an upgraded room when the server is not in the predecessor room.

@@ -1 +0,0 @@
-Fix link in the user directory documentation.

@@ -1 +0,0 @@
-Add build instructions to the docker readme.

@@ -1 +0,0 @@
-Add a test scenario to make sure room history purges don't break `/messages` in the future.

changelog.d/6394.feature (new file)
@@ -0,0 +1 @@
+Add a develop script to generate full SQL schemas.

@@ -1 +0,0 @@
-Fix bug: TypeError in `register_user()` while using LDAP auth module.

@@ -1 +0,0 @@
-Fix an intermittent exception when handling read-receipts.

@@ -1 +0,0 @@
-Add ephemeral messages support by partially implementing [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228).

changelog.d/6411.feature (new file)
@@ -0,0 +1 @@
+Allow custom SAML username mapping functionality through an external provider plugin.

@@ -1 +0,0 @@
-Fix broken guest registration when there are existing blocks of numeric user IDs.

@@ -1 +0,0 @@
-Fix startup error when http proxy is defined.

@@ -1 +0,0 @@
-Clarifications for the email configuration settings.

@@ -1 +0,0 @@
-Clean up local threepids from user on account deactivation.

@@ -1 +0,0 @@
-Add more tests to the blacklist when running in worker mode.

@@ -1 +0,0 @@
-Add support for MSC 2367, which allows specifying a reason on all membership events.

@@ -1 +0,0 @@
-Fix a bug where a room could become unusable with a low retention policy and a low activity.

@@ -1 +0,0 @@
-Switch Ubuntu package install recommendation to use python3 packages in INSTALL.md.

@@ -1 +0,0 @@
-Fix error when using synapse_port_db on a vanilla synapse db.

@@ -1 +0,0 @@
-Fix uploading multiple cross signing signatures for the same user.

@@ -1 +0,0 @@
-Move data store specific code out of `SQLBaseStore`.

@@ -1 +0,0 @@
-Write some docs for the quarantine_media api.

@@ -1 +0,0 @@
-Convert CONTRIBUTING.rst to markdown (among other small fixes).

@@ -1 +0,0 @@
-Fix bug which led to exceptions being thrown in a loop when a cross-signed device is deleted.

@@ -1 +0,0 @@
-Prepare SQLBaseStore functions being moved out of the stores.

@@ -1 +0,0 @@
-Refactor some code in the event authentication path for clarity.

@@ -1 +0,0 @@
-Move per database functionality out of the data stores and into a dedicated `Database` class.

@@ -1 +0,0 @@
-Fix `synapse_port_db` not exiting with a 0 code if something went wrong during the port process.

@@ -1 +0,0 @@
-Improve sanity-checking when receiving events over federation.

@@ -1 +0,0 @@
-Refactor some code in the event authentication path for clarity.

@@ -1 +0,0 @@
-Port synapse.rest.client.v1 to async/await.

@@ -1 +0,0 @@
-Port synapse.rest.client.v2_alpha to async/await.

@@ -1 +0,0 @@
-Port SyncHandler to async/await.

@@ -1 +0,0 @@
-Remove fallback for federation with old servers which lack the /federation/v1/state_ids API.

@@ -1 +0,0 @@
-Fix inaccurate per-block Prometheus metrics.

@@ -1 +0,0 @@
-Fix small performance regression for sending invites.

changelog.d/6501.misc (new file)
@@ -0,0 +1 @@
+Refactor get_events_from_store_or_dest to return a dict.

changelog.d/6502.removal (new file)
@@ -0,0 +1 @@
+Remove redundant code from event authorisation implementation.

changelog.d/6503.misc (new file)
@@ -0,0 +1 @@
+Move get_state methods into FederationHandler.

changelog.d/6504.misc (new file)
@@ -0,0 +1 @@
+Port handlers.account_data and handlers.account_validity to async/await.

changelog.d/6505.misc (new file)
@@ -0,0 +1 @@
+Make `make_deferred_yieldable` work with async/await.

changelog.d/6506.misc (new file)
@@ -0,0 +1 @@
+Remove `SnapshotCache` in favour of `ResponseCache`.

changelog.d/6510.misc (new file)
@@ -0,0 +1 @@
+Change phone home stats to not assume there is a single database and report information about the database used by the main data store.

changelog.d/6512.misc (new file)
@@ -0,0 +1 @@
+Silence mypy errors for files outside those specified.

changelog.d/6514.bugfix (new file)
@@ -0,0 +1 @@
+Fix race which occasionally caused deleted devices to reappear.

changelog.d/6515.misc (new file)
@@ -0,0 +1 @@
+Clean up some logging when handling incoming events over federation.

changelog.d/6517.misc (new file)
@@ -0,0 +1 @@
+Port some of FederationHandler to async/await.

@@ -25,7 +25,7 @@ Restart=on-abort
 User=synapse
 Group=nogroup
-WorkingDirectory=/opt/synapse
+WorkingDirectory=/home/synapse/synapse
 ExecStart=/home/synapse/synapse/env/bin/python -m synapse.app.homeserver --config-path=/home/synapse/synapse/homeserver.yaml
 SyslogIdentifier=matrix-synapse

@@ -0,0 +1,77 @@
# SAML Mapping Providers

A SAML mapping provider is a Python class (loaded via a Python module) that
works out how to map attributes of a SAML response object to Matrix-specific
user attributes. Details such as user ID localpart, displayname, and even avatar
URLs are all things that can be mapped from talking to an SSO service.

As an example, an SSO service may return the email address
"john.smith@example.com" for a user, whereas Synapse will need to figure out how
to turn that into a displayname when creating a Matrix user for this individual.
It may choose `John Smith`, or `Smith, John [Example.com]`, or any number of
variations. As each Synapse configuration may want something different, this is
where SAML mapping providers come into play.
## Enabling Providers

External mapping providers are provided to Synapse in the form of an external
Python module. Retrieve this module from [PyPI](https://pypi.org) or elsewhere,
then tell Synapse where to look for the handler class by editing the
`saml2_config.user_mapping_provider.module` config option.

`saml2_config.user_mapping_provider.config` allows you to provide custom
configuration options to the module. Check the module's documentation for
what options it provides (if any). The options listed by default are for the
user mapping provider built into Synapse. If using a custom module, you should
comment these options out and use those specified by the module instead.
## Building a Custom Mapping Provider

A custom mapping provider must specify the following methods:

* `__init__(self, parsed_config)`
   - Arguments:
     - `parsed_config` - A configuration object that is the return value of the
       `parse_config` method. You should set any configuration options needed by
       the module here.
* `saml_response_to_user_attributes(self, saml_response, failures)`
   - Arguments:
     - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
       information from.
     - `failures` - An `int` that represents the number of times the returned
       mxid localpart mapping has failed. This should be used
       to create a deduplicated mxid localpart which should be
       returned instead. For example, if this method returns
       `john.doe` as the value of `mxid_localpart` in the returned
       dict, and that is already taken on the homeserver, this
       method will be called again with the same parameters but
       with `failures=1`. The method should then return a different
       `mxid_localpart` value, such as `john.doe1`.
   - This method must return a dictionary, which will then be used by Synapse
     to build a new user. The following keys are allowed:
     * `mxid_localpart` - Required. The mxid localpart of the new user.
     * `displayname` - The displayname of the new user. If not provided, will default to
       the value of `mxid_localpart`.
* `parse_config(config)`
   - This method should have the `@staticmethod` decorator.
   - Arguments:
     - `config` - A `dict` representing the parsed content of the
       `saml2_config.user_mapping_provider.config` homeserver config option.
       Runs on homeserver startup. Providers should extract any option values
       they need here.
   - Whatever is returned will be passed back to the user mapping provider module's
     `__init__` method during construction.
* `get_saml_attributes(config)`
   - This method should have the `@staticmethod` decorator.
   - Arguments:
     - `config` - An object resulting from a call to `parse_config`.
   - Returns a tuple of two sets. The first set equates to the SAML auth
     response attributes that are required for the module to function, whereas
     the second set consists of those attributes which can be used if available,
     but are not necessary.
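The four methods above can be sketched as a minimal provider. This is an illustrative assumption, not code shipped with Synapse: the class name and the `source_attribute` option are made up, and attribute access via the `ava` dict is how pysaml2's `AuthnResponse` exposes mapped attributes.

```python
class MyMappingProvider:
    """A minimal, hypothetical SAML mapping provider (names illustrative)."""

    def __init__(self, parsed_config):
        # parsed_config is whatever parse_config returned
        self.source_attribute = parsed_config["source_attribute"]

    @staticmethod
    def parse_config(config):
        # Extract options from the user_mapping_provider.config dict;
        # "source_attribute" is an assumed option name for this sketch.
        return {"source_attribute": config.get("source_attribute", "uid")}

    @staticmethod
    def get_saml_attributes(config):
        # (required attributes, optional attributes)
        return {"uid"}, {"displayName"}

    def saml_response_to_user_attributes(self, saml_response, failures):
        # pysaml2's AuthnResponse exposes mapped attributes via `ava`.
        base = saml_response.ava[self.source_attribute][0].lower()
        # Deduplicate by appending the failure count, per the contract above:
        # john.doe, then john.doe1, john.doe2, ...
        suffix = str(failures) if failures else ""
        return {
            "mxid_localpart": base + suffix,
            "displayname": saml_response.ava.get("displayName", [None])[0],
        }
```

With `failures=0` this yields the plain localpart; Synapse re-calls the method with an incremented count until the localpart is free.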
## Synapse's Default Provider

Synapse uses a built-in SAML mapping provider if a custom provider isn't
specified in the config. It is located at
[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py).

@@ -1250,33 +1250,58 @@ saml2_config:
   #
   #config_path: "CONFDIR/sp_conf.py"

-  # the lifetime of a SAML session. This defines how long a user has to
+  # The lifetime of a SAML session. This defines how long a user has to
   # complete the authentication process, if allow_unsolicited is unset.
   # The default is 5 minutes.
   #
   #saml_session_lifetime: 5m

-  # The SAML attribute (after mapping via the attribute maps) to use to derive
-  # the Matrix ID from. 'uid' by default.
+  # An external module can be provided here as a custom solution to
+  # mapping attributes returned from a saml provider onto a matrix user.
   #
-  #mxid_source_attribute: displayName
+  user_mapping_provider:
+    # The custom module's class. Uncomment to use a custom module.
+    #
+    #module: mapping_provider.SamlMappingProvider

-  # The mapping system to use for mapping the saml attribute onto a matrix ID.
-  # Options include:
-  #  * 'hexencode' (which maps unpermitted characters to '=xx')
-  #  * 'dotreplace' (which replaces unpermitted characters with '.').
-  # The default is 'hexencode'.
-  #
-  #mxid_mapping: dotreplace
+    # Custom configuration values for the module. Below options are
+    # intended for the built-in provider, they should be changed if
+    # using a custom module. This section will be passed as a Python
+    # dictionary to the module's `parse_config` method.
+    #
+    config:
+      # The SAML attribute (after mapping via the attribute maps) to use
+      # to derive the Matrix ID from. 'uid' by default.
+      #
+      # Note: This used to be configured by the
+      # saml2_config.mxid_source_attribute option. If that is still
+      # defined, its value will be used instead.
+      #
+      #mxid_source_attribute: displayName

+      # The mapping system to use for mapping the saml attribute onto a
+      # matrix ID.
+      #
+      # Options include:
+      #  * 'hexencode' (which maps unpermitted characters to '=xx')
+      #  * 'dotreplace' (which replaces unpermitted characters with
+      #     '.').
+      # The default is 'hexencode'.
+      #
+      # Note: This used to be configured by the
+      # saml2_config.mxid_mapping option. If that is still defined, its
+      # value will be used instead.
+      #
+      #mxid_mapping: dotreplace

-  # In previous versions of synapse, the mapping from SAML attribute to MXID was
-  # always calculated dynamically rather than stored in a table. For backwards-
-  # compatibility, we will look for user_ids matching such a pattern before
-  # creating a new account.
+  # In previous versions of synapse, the mapping from SAML attribute to
+  # MXID was always calculated dynamically rather than stored in a
+  # table. For backwards-compatibility, we will look for user_ids
+  # matching such a pattern before creating a new account.
   #
-  # This setting controls the SAML attribute which will be used for this
-  # backwards-compatibility lookup. Typically it should be 'uid', but if the
-  # attribute maps are changed, it may be necessary to change it.
+  # This setting controls the SAML attribute which will be used for this
+  # backwards-compatibility lookup. Typically it should be 'uid', but if
+  # the attribute maps are changed, it may be necessary to change it.
   #
   # The default is 'uid'.
   #
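Under this new schema, a homeserver.yaml fragment enabling a custom module might look as follows. This is a hypothetical sketch: the module path `my_package.MyMappingProvider` and the `source_attribute` option are invented for illustration, and the `config` dict is passed verbatim to the module's `parse_config`:

```yaml
saml2_config:
  user_mapping_provider:
    module: "my_package.MyMappingProvider"
    config:
      # Assumed option name, defined by the hypothetical module itself
      source_attribute: "uid"
```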

@@ -1,7 +1,7 @@
 [mypy]
 namespace_packages = True
 plugins = mypy_zope:plugin
-follow_imports = normal
+follow_imports = silent
 check_untyped_defs = True
 show_error_codes = True
 show_traceback = True

scripts-dev/make_full_schema.sh (new executable file, 184 lines)

@@ -0,0 +1,184 @@
#!/bin/bash
#
# This script generates SQL files for creating a brand new Synapse DB with the latest
# schema, on both SQLite3 and Postgres.
#
# It does so by having Synapse generate an up-to-date SQLite DB, then running
# synapse_port_db to convert it to Postgres. It then dumps the contents of both.

POSTGRES_HOST="localhost"
POSTGRES_DB_NAME="synapse_full_schema.$$"
SQLITE_FULL_SCHEMA_OUTPUT_FILE="full.sql.sqlite"
POSTGRES_FULL_SCHEMA_OUTPUT_FILE="full.sql.postgres"
REQUIRED_DEPS=("matrix-synapse" "psycopg2")

usage() {
  echo
  echo "Usage: $0 -p <postgres_username> -o <path> [-c] [-n] [-h]"
  echo
  echo "-p <postgres_username>"
  echo "   Username to connect to local postgres instance. The password will be requested"
  echo "   during script execution."
  echo "-c"
  echo "   CI mode. Enables coverage tracking and prints every command that the script runs."
  echo "-o <path>"
  echo "   Directory to output full schema files to."
  echo "-h"
  echo "   Display this help text."
}
while getopts "p:co:h" opt; do
  case $opt in
    p)
      POSTGRES_USERNAME=$OPTARG
      ;;
    c)
      # Print all commands that are being executed
      set -x

      # Modify required dependencies for coverage
      REQUIRED_DEPS+=("coverage" "coverage-enable-subprocess")
      COVERAGE=1
      ;;
    o)
      command -v realpath > /dev/null || (echo "The -o flag requires the 'realpath' binary to be installed" && exit 1)
      OUTPUT_DIR="$(realpath "$OPTARG")"
      ;;
    h)
      usage
      exit
      ;;
    \?)
      echo "ERROR: Invalid option: -$OPTARG" >&2
      usage
      exit
      ;;
  esac
done

# Check that required dependencies are installed
unsatisfied_requirements=()
for dep in "${REQUIRED_DEPS[@]}"; do
  pip show "$dep" --quiet || unsatisfied_requirements+=("$dep")
done
if [ ${#unsatisfied_requirements[@]} -ne 0 ]; then
  echo "Please install the following python packages: ${unsatisfied_requirements[*]}"
  exit 1
fi

if [ -z "$POSTGRES_USERNAME" ]; then
  echo "No postgres username supplied"
  usage
  exit 1
fi

if [ -z "$OUTPUT_DIR" ]; then
  echo "No output directory supplied"
  usage
  exit 1
fi
# Create the output directory if it doesn't exist
mkdir -p "$OUTPUT_DIR"
read -rsp "Postgres password for '$POSTGRES_USERNAME': " POSTGRES_PASSWORD
echo ""
# Exit immediately if a command fails
set -e
# cd to root of the synapse directory
cd "$(dirname "$0")/.."
# Create temporary SQLite and Postgres homeserver db configs and key file
TMPDIR=$(mktemp -d)
KEY_FILE=$TMPDIR/test.signing.key # default Synapse signing key path
SQLITE_CONFIG=$TMPDIR/sqlite.conf
SQLITE_DB=$TMPDIR/homeserver.db
POSTGRES_CONFIG=$TMPDIR/postgres.conf
# Ensure these files are deleted on script exit
trap 'rm -rf $TMPDIR' EXIT
cat > "$SQLITE_CONFIG" <<EOF
server_name: "test"
signing_key_path: "$KEY_FILE"
macaroon_secret_key: "abcde"
report_stats: false
database:
name: "sqlite3"
args:
database: "$SQLITE_DB"
# Suppress the key server warning.
trusted_key_servers: []
EOF
cat > "$POSTGRES_CONFIG" <<EOF
server_name: "test"
signing_key_path: "$KEY_FILE"
macaroon_secret_key: "abcde"
report_stats: false
database:
name: "psycopg2"
args:
user: "$POSTGRES_USERNAME"
host: "$POSTGRES_HOST"
password: "$POSTGRES_PASSWORD"
database: "$POSTGRES_DB_NAME"
# Suppress the key server warning.
trusted_key_servers: []
EOF
# Generate the server's signing key.
echo "Generating SQLite3 db schema..."
python -m synapse.app.homeserver --generate-keys -c "$SQLITE_CONFIG"
# Make sure the SQLite3 database is using the latest schema and has no pending background update.
echo "Running db background jobs..."
scripts-dev/update_database --database-config "$SQLITE_CONFIG"
# Create the PostgreSQL database.
echo "Creating postgres database..."
createdb $POSTGRES_DB_NAME
echo "Copying data from SQLite3 to Postgres with synapse_port_db..."
if [ -z "$COVERAGE" ]; then
  # No coverage needed
  scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
else
  # Coverage desired
  coverage run scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
fi
# Delete schema_version, applied_schema_deltas and applied_module_schemas tables
# This needs to be done after synapse_port_db is run
echo "Dropping unwanted db tables..."
SQL="
DROP TABLE schema_version;
DROP TABLE applied_schema_deltas;
DROP TABLE applied_module_schemas;
"
sqlite3 "$SQLITE_DB" <<< "$SQL"
psql $POSTGRES_DB_NAME -U "$POSTGRES_USERNAME" -w <<< "$SQL"
echo "Dumping SQLite3 schema to '$OUTPUT_DIR/$SQLITE_FULL_SCHEMA_OUTPUT_FILE'..."
sqlite3 "$SQLITE_DB" ".dump" > "$OUTPUT_DIR/$SQLITE_FULL_SCHEMA_OUTPUT_FILE"
echo "Dumping Postgres schema to '$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE'..."
pg_dump --format=plain --no-tablespaces --no-acl --no-owner $POSTGRES_DB_NAME | sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > "$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE"
echo "Cleaning up temporary Postgres database..."
dropdb $POSTGRES_DB_NAME
echo "Done! Files dumped to: $OUTPUT_DIR"

@@ -55,6 +55,7 @@ from synapse.storage.data_stores.main.stats import StatsStore
 from synapse.storage.data_stores.main.user_directory import (
     UserDirectoryBackgroundUpdateStore,
 )
+from synapse.storage.database import Database
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 from synapse.util import Clock
@@ -139,39 +140,6 @@ class Store(
     UserDirectoryBackgroundUpdateStore,
     StatsStore,
 ):
-    def __init__(self, db_conn, hs):
-        super().__init__(db_conn, hs)
-        self.db_pool = hs.get_db_pool()
-
-    @defer.inlineCallbacks
-    def runInteraction(self, desc, func, *args, **kwargs):
-        def r(conn):
-            try:
-                i = 0
-                N = 5
-                while True:
-                    try:
-                        txn = conn.cursor()
-                        return func(
-                            LoggingTransaction(txn, desc, self.database_engine, [], []),
-                            *args,
-                            **kwargs
-                        )
-                    except self.database_engine.module.DatabaseError as e:
-                        if self.database_engine.is_deadlock(e):
-                            logger.warning("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
-                            if i < N:
-                                i += 1
-                                conn.rollback()
-                                continue
-                        raise
-            except Exception as e:
-                logger.debug("[TXN FAIL] {%s} %s", desc, e)
-                raise
-
-        with PreserveLoggingContext():
-            return (yield self.db_pool.runWithConnection(r))
-
     def execute(self, f, *args, **kwargs):
         return self.db.runInteraction(f.__name__, f, *args, **kwargs)

@@ -512,7 +480,7 @@ class Porter(object):
             hs = MockHomeserver(self.hs_config, engine, conn, db_pool)

-            store = Store(conn, hs)
+            store = Store(Database(hs), conn, hs)

             yield store.db.runInteraction(
                 "%s_engine.check_database" % config["name"], engine.check_database,
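The removed `runInteraction` wrapper retried a transaction up to five times when the database reported a deadlock, re-raising any other error. Stripped of Twisted and Synapse specifics, the retry logic has this shape; the function and class names below are illustrative, not Synapse APIs:

```python
class DeadlockError(Exception):
    """Stands in for the driver-specific deadlock error class."""


def run_with_retries(func, is_deadlock, max_retries=5):
    """Call func(), retrying up to max_retries times on deadlock errors.

    Mirrors the removed loop: retry only when the predicate says the
    error is a deadlock; re-raise anything else, or once retries are
    exhausted.
    """
    attempt = 0
    while True:
        try:
            return func()
        except Exception as e:
            if is_deadlock(e) and attempt < max_retries:
                attempt += 1
                continue  # the original also rolled back the connection here
            raise
```

The bounded counter matters: an unconditional retry on deadlock could loop forever if two transactions keep colliding.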

@@ -36,7 +36,7 @@ try:
 except ImportError:
     pass

-__version__ = "1.6.1"
+__version__ = "1.7.0rc2"

 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when

View file

@@ -40,6 +40,7 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.replication.tcp.streams._base import ReceiptsStream
from synapse.server import HomeServer
+from synapse.storage.database import Database
from synapse.storage.engines import create_engine
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
@@ -59,8 +60,8 @@ class FederationSenderSlaveStore(
SlavedDeviceStore,
SlavedPresenceStore,
):
-def __init__(self, db_conn, hs):
-super(FederationSenderSlaveStore, self).__init__(db_conn, hs)
+def __init__(self, database: Database, db_conn, hs):
+super(FederationSenderSlaveStore, self).__init__(database, db_conn, hs)
# We pull out the current federation stream position now so that we
# always have a known value for the federation position in memory so
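The constructor change above is one instance of a pattern repeated throughout this commit: store classes now receive the shared `Database` object explicitly instead of deriving it from `db_conn`/`hs`. A toy sketch of the shape of that refactor (all class names here are illustrative, not the real Synapse classes):

```python
class Database:
    """Illustrative stand-in for synapse.storage.database.Database."""
    def __init__(self, engine_name):
        self.engine_name = engine_name


class BaseStore:
    def __init__(self, database, db_conn, hs):
        # The Database object is passed in explicitly, so every store
        # mixin in the MRO chain sees the same instance via self.db.
        self.db = database


class FederationSenderStore(BaseStore):
    def __init__(self, database, db_conn, hs):
        super().__init__(database, db_conn, hs)
        self.federation_position = 0  # would really be read via self.db


# Usage: the caller constructs the Database once and threads it through.
db = Database("psycopg2")
store = FederationSenderStore(db, db_conn=None, hs=None)
```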


@@ -68,9 +68,9 @@ from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
-from synapse.storage import DataStore, are_all_users_on_domain
+from synapse.storage import DataStore
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
-from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
+from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
@@ -294,22 +294,6 @@ class SynapseHomeServer(HomeServer):
else:
logger.warning("Unrecognized listener type: %s", listener["type"])
-def run_startup_checks(self, db_conn, database_engine):
-all_users_native = are_all_users_on_domain(
-db_conn.cursor(), database_engine, self.hostname
-)
-if not all_users_native:
-quit_with_error(
-"Found users in database not native to %s!\n"
-"You cannot changed a synapse server_name after it's been configured"
-% (self.hostname,)
-)
-try:
-database_engine.check_database(db_conn.cursor())
-except IncorrectDatabaseSetup as e:
-quit_with_error(str(e))
# Gauges to expose monthly active user control metrics
current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
@@ -357,16 +341,12 @@ def setup(config_options):
synapse.config.logger.setup_logging(hs, config, use_worker_options=False)
-logger.info("Preparing database: %s...", config.database_config["name"])
+logger.info("Setting up server")
try:
-with hs.get_db_conn(run_new_connection=False) as db_conn:
-prepare_database(db_conn, database_engine, config=config)
-database_engine.on_new_connection(db_conn)
-hs.run_startup_checks(db_conn, database_engine)
-db_conn.commit()
+hs.setup()
+except IncorrectDatabaseSetup as e:
+quit_with_error(str(e))
except UpgradeDatabaseException:
sys.stderr.write(
"\nFailed to upgrade database.\n"
@@ -375,9 +355,6 @@ def setup(config_options):
)
sys.exit(1)
-logger.info("Database prepared in %s.", config.database_config["name"])
-hs.setup()
hs.setup_master()
@defer.inlineCallbacks
@@ -542,8 +519,10 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
# Database version
#
-stats["database_engine"] = hs.database_engine.module.__name__
-stats["database_server_version"] = hs.database_engine.server_version
+# This only reports info about the *main* database.
+stats["database_engine"] = hs.get_datastore().db.engine.module.__name__
+stats["database_server_version"] = hs.get_datastore().db.engine.server_version
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
try:
yield hs.get_proxied_http_client().put_json(


@@ -33,6 +33,7 @@ from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage import DataStore
@@ -45,7 +46,11 @@ logger = logging.getLogger("synapse.app.pusher")
class PusherSlaveStore(
-SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore, SlavedAccountDataStore
+SlavedEventStore,
+SlavedPusherStore,
+SlavedReceiptsStore,
+SlavedAccountDataStore,
+RoomStore,
):
update_pusher_last_stream_ordering_and_success = __func__(
DataStore.update_pusher_last_stream_ordering_and_success


@@ -43,6 +43,7 @@ from synapse.replication.tcp.streams.events import (
from synapse.rest.client.v2_alpha import user_directory
from synapse.server import HomeServer
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
+from synapse.storage.database import Database
from synapse.storage.engines import create_engine
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.httpresourcetree import create_resource_tree
@@ -60,8 +61,8 @@ class UserDirectorySlaveStore(
UserDirectoryStore,
BaseSlavedStore,
):
-def __init__(self, db_conn, hs):
-super(UserDirectorySlaveStore, self).__init__(db_conn, hs)
+def __init__(self, database: Database, db_conn, hs):
+super(UserDirectorySlaveStore, self).__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(


@@ -14,17 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import re
+import logging
from synapse.python_dependencies import DependencyException, check_requirements
-from synapse.types import (
-map_username_to_mxid_localpart,
-mxid_localpart_allowed_characters,
-)
-from synapse.util.module_loader import load_python_module
+from synapse.util.module_loader import load_module, load_python_module
from ._base import Config, ConfigError
+logger = logging.getLogger(__name__)
+DEFAULT_USER_MAPPING_PROVIDER = (
+"synapse.handlers.saml_handler.DefaultSamlMappingProvider"
+)
def _dict_merge(merge_dict, into_dict):
"""Do a deep merge of two dicts
@@ -75,15 +77,69 @@ class SAML2Config(Config):
self.saml2_enabled = True
-self.saml2_mxid_source_attribute = saml2_config.get(
-"mxid_source_attribute", "uid"
-)
self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
"grandfathered_mxid_source_attribute", "uid"
)
-saml2_config_dict = self._default_saml_config_dict()
+# user_mapping_provider may be None if the key is present but has no value
+ump_dict = saml2_config.get("user_mapping_provider") or {}
+# Use the default user mapping provider if not set
+ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+# Ensure a config is present
+ump_dict["config"] = ump_dict.get("config") or {}
+if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
+# Load deprecated options for use by the default module
+old_mxid_source_attribute = saml2_config.get("mxid_source_attribute")
+if old_mxid_source_attribute:
+logger.warning(
+"The config option saml2_config.mxid_source_attribute is deprecated. "
+"Please use saml2_config.user_mapping_provider.config"
+".mxid_source_attribute instead."
+)
+ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute
+old_mxid_mapping = saml2_config.get("mxid_mapping")
+if old_mxid_mapping:
+logger.warning(
+"The config option saml2_config.mxid_mapping is deprecated. Please "
+"use saml2_config.user_mapping_provider.config.mxid_mapping instead."
+)
+ump_dict["config"]["mxid_mapping"] = old_mxid_mapping
+# Retrieve an instance of the module's class
+# Pass the config dictionary to the module for processing
+(
+self.saml2_user_mapping_provider_class,
+self.saml2_user_mapping_provider_config,
+) = load_module(ump_dict)
+# Ensure loaded user mapping module has defined all necessary methods
+# Note parse_config() is already checked during the call to load_module
+required_methods = [
+"get_saml_attributes",
+"saml_response_to_user_attributes",
+]
+missing_methods = [
+method
+for method in required_methods
+if not hasattr(self.saml2_user_mapping_provider_class, method)
+]
+if missing_methods:
+raise ConfigError(
+"Class specified by saml2_config."
+"user_mapping_provider.module is missing required "
+"methods: %s" % (", ".join(missing_methods),)
+)
+# Get the desired saml auth response attributes from the module
+saml2_config_dict = self._default_saml_config_dict(
+*self.saml2_user_mapping_provider_class.get_saml_attributes(
+self.saml2_user_mapping_provider_config
+)
+)
_dict_merge(
merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
)
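The provider-loading code above reduces to: resolve a class from a dotted module path, then confirm it defines every method the SAML handler will call. That validation step can be sketched on its own (`validate_provider` and the two provider classes are hypothetical, not Synapse's):

```python
REQUIRED_METHODS = ["get_saml_attributes", "saml_response_to_user_attributes"]


class ConfigError(Exception):
    pass


def validate_provider(provider_class):
    """Raise ConfigError naming every required method the class lacks,
    mirroring the missing_methods check in the diff above."""
    missing = [m for m in REQUIRED_METHODS if not hasattr(provider_class, m)]
    if missing:
        raise ConfigError("missing required methods: %s" % ", ".join(missing))


class GoodProvider:
    def get_saml_attributes(self): ...
    def saml_response_to_user_attributes(self): ...


class BadProvider:
    # lacks saml_response_to_user_attributes
    def get_saml_attributes(self): ...
```

Failing at config-parse time like this surfaces a broken custom module at startup rather than on the first SAML login.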
@@ -103,22 +159,27 @@ class SAML2Config(Config):
saml2_config.get("saml_session_lifetime", "5m")
)
-mapping = saml2_config.get("mxid_mapping", "hexencode")
-try:
-self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
-except KeyError:
-raise ConfigError("%s is not a known mxid_mapping" % (mapping,))
-def _default_saml_config_dict(self):
+def _default_saml_config_dict(
+self, required_attributes: set, optional_attributes: set
+):
+"""Generate a configuration dictionary with required and optional attributes that
+will be needed to process new user registration
+Args:
+required_attributes: SAML auth response attributes that are
+necessary to function
+optional_attributes: SAML auth response attributes that can be used to add
+additional information to Synapse user accounts, but are not required
+Returns:
+dict: A SAML configuration dictionary
+"""
import saml2
public_baseurl = self.public_baseurl
if public_baseurl is None:
raise ConfigError("saml2_config requires a public_baseurl to be set")
-required_attributes = {"uid", self.saml2_mxid_source_attribute}
-optional_attributes = {"displayName"}
if self.saml2_grandfathered_mxid_source_attribute:
optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
optional_attributes -= required_attributes
@@ -207,33 +268,58 @@ class SAML2Config(Config):
#
#config_path: "%(config_dir_path)s/sp_conf.py"
-# the lifetime of a SAML session. This defines how long a user has to
+# The lifetime of a SAML session. This defines how long a user has to
# complete the authentication process, if allow_unsolicited is unset.
# The default is 5 minutes.
#
#saml_session_lifetime: 5m
-# The SAML attribute (after mapping via the attribute maps) to use to derive
-# the Matrix ID from. 'uid' by default.
-#
-#mxid_source_attribute: displayName
-# The mapping system to use for mapping the saml attribute onto a matrix ID.
-# Options include:
-# * 'hexencode' (which maps unpermitted characters to '=xx')
-# * 'dotreplace' (which replaces unpermitted characters with '.').
-# The default is 'hexencode'.
-#
-#mxid_mapping: dotreplace
-# In previous versions of synapse, the mapping from SAML attribute to MXID was
-# always calculated dynamically rather than stored in a table. For backwards-
-# compatibility, we will look for user_ids matching such a pattern before
-# creating a new account.
+# An external module can be provided here as a custom solution to
+# mapping attributes returned from a saml provider onto a matrix user.
+#
+user_mapping_provider:
+# The custom module's class. Uncomment to use a custom module.
+#
+#module: mapping_provider.SamlMappingProvider
+# Custom configuration values for the module. Below options are
+# intended for the built-in provider, they should be changed if
+# using a custom module. This section will be passed as a Python
+# dictionary to the module's `parse_config` method.
+#
+config:
+# The SAML attribute (after mapping via the attribute maps) to use
+# to derive the Matrix ID from. 'uid' by default.
+#
+# Note: This used to be configured by the
+# saml2_config.mxid_source_attribute option. If that is still
+# defined, its value will be used instead.
+#
+#mxid_source_attribute: displayName
+# The mapping system to use for mapping the saml attribute onto a
+# matrix ID.
+#
+# Options include:
+# * 'hexencode' (which maps unpermitted characters to '=xx')
+# * 'dotreplace' (which replaces unpermitted characters with
+#   '.').
+# The default is 'hexencode'.
+#
+# Note: This used to be configured by the
+# saml2_config.mxid_mapping option. If that is still defined, its
+# value will be used instead.
+#
+#mxid_mapping: dotreplace
+# In previous versions of synapse, the mapping from SAML attribute to
+# MXID was always calculated dynamically rather than stored in a
+# table. For backwards- compatibility, we will look for user_ids
+# matching such a pattern before creating a new account.
#
# This setting controls the SAML attribute which will be used for this
-# backwards-compatibility lookup. Typically it should be 'uid', but if the
-# attribute maps are changed, it may be necessary to change it.
+# backwards-compatibility lookup. Typically it should be 'uid', but if
+# the attribute maps are changed, it may be necessary to change it.
#
# The default is 'uid'.
#
@@ -241,23 +327,3 @@ class SAML2Config(Config):
""" % {
"config_dir_path": config_dir_path
}
-DOT_REPLACE_PATTERN = re.compile(
-("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
-)
-def dot_replace_for_mxid(username: str) -> str:
-username = username.lower()
-username = DOT_REPLACE_PATTERN.sub(".", username)
-# regular mxids aren't allowed to start with an underscore either
-username = re.sub("^_", "", username)
-return username
-MXID_MAPPER_MAP = {
-"hexencode": map_username_to_mxid_localpart,
-"dotreplace": dot_replace_for_mxid,
-}


@@ -42,6 +42,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=True):
Returns:
if the auth checks pass.
"""
+assert isinstance(auth_events, dict)
if do_size_check:
_check_size_limits(event)
@@ -74,12 +76,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=True):
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
-if auth_events is None:
-# Oh, we don't know what the state of the room was, so we
-# are trusting that this is allowed (at least for now)
-logger.warning("Trusting event: %s", event.event_id)
-return
if event.type == EventTypes.Create:
sender_domain = get_domain_from_id(event.sender)
room_id_domain = get_domain_from_id(event.room_id)


@@ -18,8 +18,6 @@ import copy
import itertools
import logging
-from six.moves import range
from prometheus_client import Counter
from twisted.internet import defer
@@ -39,7 +37,7 @@ from synapse.api.room_versions import (
)
from synapse.events import builder, room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import make_deferred_yieldable
from synapse.logging.utils import log_function
from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
@@ -310,19 +308,12 @@ class FederationClient(FederationBase):
return signed_pdu
@defer.inlineCallbacks
-@log_function
-def get_state_for_room(self, destination, room_id, event_id):
-"""Requests all of the room state at a given event from a remote homeserver.
-Args:
-destination (str): The remote homeserver to query for the state.
-room_id (str): The id of the room we're interested in.
-event_id (str): The id of the event we want the state at.
-Returns:
-Deferred[Tuple[List[EventBase], List[EventBase]]]:
-A list of events in the state, and a list of events in the auth chain
-for the given event.
-"""
+def get_room_state_ids(self, destination: str, room_id: str, event_id: str):
+"""Calls the /state_ids endpoint to fetch the state at a particular point
+in the room, and the auth events for the given event
+Returns:
+Tuple[List[str], List[str]]: a tuple of (state event_ids, auth event_ids)
+"""
result = yield self.transport_layer.get_room_state_ids(
destination, room_id, event_id=event_id
@@ -331,86 +322,12 @@ class FederationClient(FederationBase):
state_event_ids = result["pdu_ids"]
auth_event_ids = result.get("auth_chain_ids", [])
-fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
-destination, room_id, set(state_event_ids + auth_event_ids)
-)
-if failed_to_fetch:
-logger.warning(
-"Failed to fetch missing state/auth events for %s: %s",
-room_id,
-failed_to_fetch,
-)
-event_map = {ev.event_id: ev for ev in fetched_events}
-pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
-auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
-auth_chain.sort(key=lambda e: e.depth)
-return pdus, auth_chain
+if not isinstance(state_event_ids, list) or not isinstance(
+auth_event_ids, list
+):
+raise Exception("invalid response from /state_ids")
+return state_event_ids, auth_event_ids
-@defer.inlineCallbacks
-def get_events_from_store_or_dest(self, destination, room_id, event_ids):
-"""Fetch events from a remote destination, checking if we already have them.
-Args:
-destination (str)
-room_id (str)
-event_ids (list)
-Returns:
-Deferred: A deferred resolving to a 2-tuple where the first is a list of
-events and the second is a list of event ids that we failed to fetch.
-"""
-seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
-signed_events = list(seen_events.values())
-failed_to_fetch = set()
-missing_events = set(event_ids)
-for k in seen_events:
-missing_events.discard(k)
-if not missing_events:
-return signed_events, failed_to_fetch
-logger.debug(
-"Fetching unknown state/auth events %s for room %s",
-missing_events,
-event_ids,
-)
-room_version = yield self.store.get_room_version(room_id)
-batch_size = 20
-missing_events = list(missing_events)
-for i in range(0, len(missing_events), batch_size):
-batch = set(missing_events[i : i + batch_size])
-deferreds = [
-run_in_background(
-self.get_pdu,
-destinations=[destination],
-event_id=e_id,
-room_version=room_version,
-)
-for e_id in batch
-]
-res = yield make_deferred_yieldable(
-defer.DeferredList(deferreds, consumeErrors=True)
-)
-for success, result in res:
-if success and result:
-signed_events.append(result)
-batch.discard(result.event_id)
-# We removed all events we successfully fetched from `batch`
-failed_to_fetch.update(batch)
-return signed_events, failed_to_fetch
@defer.inlineCallbacks
@log_function
@@ -609,13 +526,7 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_request(destination):
-time_now = self._clock.time_msec()
-_, content = yield self.transport_layer.send_join(
-destination=destination,
-room_id=pdu.room_id,
-event_id=pdu.event_id,
-content=pdu.get_pdu_json(time_now),
-)
+content = yield self._do_send_join(destination, pdu)
logger.debug("Got content: %s", content)
@@ -682,6 +593,44 @@ class FederationClient(FederationBase):
return self._try_destination_list("send_join", destinations, send_request)
+@defer.inlineCallbacks
+def _do_send_join(self, destination, pdu):
+time_now = self._clock.time_msec()
+try:
+content = yield self.transport_layer.send_join_v2(
+destination=destination,
+room_id=pdu.room_id,
+event_id=pdu.event_id,
+content=pdu.get_pdu_json(time_now),
+)
+return content
+except HttpResponseException as e:
+if e.code in [400, 404]:
+err = e.to_synapse_error()
+# If we receive an error response that isn't a generic error, or an
+# unrecognised endpoint error, we assume that the remote understands
+# the v2 invite API and this is a legitimate error.
+if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
+raise err
+else:
+raise e.to_synapse_error()
+logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
+resp = yield self.transport_layer.send_join_v1(
+destination=destination,
+room_id=pdu.room_id,
+event_id=pdu.event_id,
+content=pdu.get_pdu_json(time_now),
+)
+# We expect the v1 API to respond with [200, content], so we only return the
+# content.
+return resp[1]
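`_do_send_join` (and `_do_send_leave` below) implement the same negotiation: try the v2 endpoint first, and fall back to v1 only when a 400/404 carries an "unknown"/"unrecognised" errcode, i.e. when the remote plausibly has no v2 route at all. The control flow, with hypothetical transport objects standing in for Synapse's transport layer:

```python
class HttpResponseError(Exception):
    """Stand-in for an HTTP error carrying a Matrix errcode."""
    def __init__(self, code, errcode):
        super().__init__(code, errcode)
        self.code = code
        self.errcode = errcode  # e.g. "M_UNKNOWN", "M_UNRECOGNIZED"


def send_with_fallback(transport, pdu):
    """Prefer the v2 endpoint; fall back to v1 only on a 400/404 whose
    errcode suggests the remote does not implement the v2 route."""
    try:
        return transport.send_v2(pdu)
    except HttpResponseError as e:
        if e.code not in (400, 404) or e.errcode not in ("M_UNKNOWN", "M_UNRECOGNIZED"):
            raise  # a real error from a v2-aware server: do not retry on v1
    return transport.send_v1(pdu)


class OldServerTransport:
    """Simulates a remote that has never heard of the v2 route."""
    def send_v2(self, pdu):
        raise HttpResponseError(404, "M_UNRECOGNIZED")

    def send_v1(self, pdu):
        return (200, {"state": []})


class V2AwareTransport:
    """Simulates a v2-aware remote returning a real error: no fallback."""
    def send_v2(self, pdu):
        raise HttpResponseError(403, "M_FORBIDDEN")
```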
@defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
room_version = yield self.store.get_room_version(room_id)
@@ -791,18 +740,50 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_request(destination):
-time_now = self._clock.time_msec()
-_, content = yield self.transport_layer.send_leave(
-destination=destination,
-room_id=pdu.room_id,
-event_id=pdu.event_id,
-content=pdu.get_pdu_json(time_now),
-)
+content = yield self._do_send_leave(destination, pdu)
logger.debug("Got content: %s", content)
return None
return self._try_destination_list("send_leave", destinations, send_request)
+@defer.inlineCallbacks
+def _do_send_leave(self, destination, pdu):
+time_now = self._clock.time_msec()
+try:
+content = yield self.transport_layer.send_leave_v2(
+destination=destination,
+room_id=pdu.room_id,
+event_id=pdu.event_id,
+content=pdu.get_pdu_json(time_now),
+)
+return content
+except HttpResponseException as e:
+if e.code in [400, 404]:
+err = e.to_synapse_error()
+# If we receive an error response that isn't a generic error, or an
+# unrecognised endpoint error, we assume that the remote understands
+# the v2 invite API and this is a legitimate error.
+if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
+raise err
+else:
+raise e.to_synapse_error()
+logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")
+resp = yield self.transport_layer.send_leave_v1(
+destination=destination,
+room_id=pdu.room_id,
+event_id=pdu.event_id,
+content=pdu.get_pdu_json(time_now),
+)
+# We expect the v1 API to respond with [200, content], so we only return the
+# content.
+return resp[1]
def get_public_rooms(
self,


@@ -384,15 +384,10 @@ class FederationServer(FederationBase):
res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
-return (
-200,
-{
-"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
-"auth_chain": [
-p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
-],
-},
-)
+return {
+"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
+"auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
+}
async def on_make_leave_request(self, origin, room_id, user_id):
origin_host, _ = parse_server_name(origin)
@@ -419,7 +414,7 @@ class FederationServer(FederationBase):
pdu = await self._check_sigs_and_hash(room_version, pdu)
await self.handler.on_send_leave_request(origin, pdu)
-return 200, {}
+return {}
async def on_event_auth(self, origin, room_id, event_id):
with (await self._server_linearizer.queue((origin, room_id))):


@@ -243,7 +243,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
-def send_join(self, destination, room_id, event_id, content):
+def send_join_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
@@ -254,7 +254,18 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
-def send_leave(self, destination, room_id, event_id, content):
+def send_join_v2(self, destination, room_id, event_id, content):
+path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
+response = yield self.client.put_json(
+destination=destination, path=path, data=content
+)
+return response
+@defer.inlineCallbacks
+@log_function
+def send_leave_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
@@ -270,6 +281,24 @@ class TransportLayerClient(object):
return response
+@defer.inlineCallbacks
+@log_function
+def send_leave_v2(self, destination, room_id, event_id, content):
+path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
+response = yield self.client.put_json(
+destination=destination,
+path=path,
+data=content,
+# we want to do our best to send this through. The problem is
+# that if it fails, we won't retry it later, so if the remote
+# server was just having a momentary blip, the room will be out of
+# sync.
+ignore_backoff=True,
+)
+return response
@defer.inlineCallbacks
@log_function
def send_invite_v1(self, destination, room_id, event_id, content):


@@ -506,9 +506,19 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
return 200, content
-class FederationSendLeaveServlet(BaseFederationServlet):
+class FederationV1SendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+async def on_PUT(self, origin, content, query, room_id, event_id):
+content = await self.handler.on_send_leave_request(origin, content, room_id)
+return 200, (200, content)
+class FederationV2SendLeaveServlet(BaseFederationServlet):
+PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+PREFIX = FEDERATION_V2_PREFIX
async def on_PUT(self, origin, content, query, room_id, event_id):
content = await self.handler.on_send_leave_request(origin, content, room_id)
return 200, content
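The only behavioural difference between the two servlets is the response envelope: the legacy v1 route double-wraps the body as `(200, content)` inside the HTTP 200, while v2 returns the content directly. Sketched with plain functions standing in for the servlet `on_PUT` handlers:

```python
def on_put_v1(content):
    # v1: the body is itself a [200, content] pair, kept for compatibility
    return 200, (200, content)


def on_put_v2(content):
    # v2: the body is just the content
    return 200, content


body = {"state": [], "auth_chain": []}
v1_status, v1_body = on_put_v1(body)
v2_status, v2_body = on_put_v2(body)
```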
@ -521,9 +531,21 @@ class FederationEventAuthServlet(BaseFederationServlet):
return await self.handler.on_event_auth(origin, context, event_id) return await self.handler.on_event_auth(origin, context, event_id)
class FederationSendJoinServlet(BaseFederationServlet): class FederationV1SendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
content = await self.handler.on_send_join_request(origin, content, context)
return 200, (200, content)
class FederationV2SendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
async def on_PUT(self, origin, content, query, context, event_id): async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually # TODO(paul): assert that context/event_id parsed from path actually
# match those given in content # match those given in content
@ -1367,8 +1389,10 @@ FEDERATION_SERVLET_CLASSES = (
FederationMakeJoinServlet, FederationMakeJoinServlet,
FederationMakeLeaveServlet, FederationMakeLeaveServlet,
FederationEventServlet, FederationEventServlet,
FederationSendJoinServlet, FederationV1SendJoinServlet,
FederationSendLeaveServlet, FederationV2SendJoinServlet,
FederationV1SendLeaveServlet,
FederationV2SendLeaveServlet,
FederationV1InviteServlet, FederationV1InviteServlet,
FederationV2InviteServlet, FederationV2InviteServlet,
FederationQueryAuthServlet, FederationQueryAuthServlet,
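The split above registers separate v1 and v2 send_join servlets. The diff only shows the v1 handler body, which keeps the legacy response envelope by wrapping the join response in a redundant `(200, content)` tuple; the v2 body is not shown in this hunk, but presumably returns the content unwrapped. A minimal sketch of the two response shapes (the helper names are hypothetical, not the actual servlet code):

```python
# Hypothetical helpers illustrating the v1 vs v2 send_join response envelopes.

def v1_send_join_response(content: dict) -> tuple:
    # v1 wraps the response in a legacy (200, content) inner envelope.
    return 200, (200, content)

def v2_send_join_response(content: dict) -> tuple:
    # v2 (assumed here) drops the redundant inner envelope.
    return 200, content

status, body = v1_send_join_response({"state": []})
print(body)  # (200, {'state': []})
```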

View file

@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from twisted.internet import defer
-

 class AccountDataEventSource(object):
     def __init__(self, hs):
@@ -23,15 +21,14 @@ class AccountDataEventSource(object):
     def get_current_key(self, direction="f"):
         return self.store.get_max_account_data_stream_id()

-    @defer.inlineCallbacks
-    def get_new_events(self, user, from_key, **kwargs):
+    async def get_new_events(self, user, from_key, **kwargs):
         user_id = user.to_string()
         last_stream_id = from_key

-        current_stream_id = yield self.store.get_max_account_data_stream_id()
+        current_stream_id = self.store.get_max_account_data_stream_id()

         results = []
-        tags = yield self.store.get_updated_tags(user_id, last_stream_id)
+        tags = await self.store.get_updated_tags(user_id, last_stream_id)

         for room_id, room_tags in tags.items():
             results.append(
@@ -41,7 +38,7 @@ class AccountDataEventSource(object):
         (
             account_data,
             room_account_data,
-        ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
+        ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)

         for account_data_type, content in account_data.items():
             results.append({"type": account_data_type, "content": content})
@@ -54,6 +51,5 @@ class AccountDataEventSource(object):
         return results, current_stream_id

-    @defer.inlineCallbacks
-    def get_pagination_rows(self, user, config, key):
+    async def get_pagination_rows(self, user, config, key):
         return [], config.to_id
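The change above follows the mechanical pattern used throughout this commit: drop `@defer.inlineCallbacks`, declare the method `async def`, and replace each `yield deferred` with `await`. A self-contained sketch of that pattern using asyncio (Synapse itself runs these coroutines on the Twisted reactor; the `FakeStore` and its return value are stand-ins for illustration):

```python
import asyncio

class FakeStore:
    """Stand-in for the real data store, for illustration only."""
    async def get_updated_tags(self, user_id, last_stream_id):
        return {"!room:example.org": {"m.favourite": {}}}

class AccountDataEventSourceSketch:
    def __init__(self, store):
        self.store = store

    # Before: @defer.inlineCallbacks + `yield deferred`
    # After:  a native coroutine + `await awaitable`
    async def get_new_events(self, user_id, from_key):
        tags = await self.store.get_updated_tags(user_id, from_key)
        return [
            {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
            for room_id, room_tags in tags.items()
        ]

events = asyncio.run(
    AccountDataEventSourceSketch(FakeStore()).get_new_events("@u:example.org", 0)
)
print(events[0]["type"])  # m.tag
```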

View file

@@ -18,8 +18,7 @@ import email.utils
 import logging
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
-
-from twisted.internet import defer
+from typing import List

 from synapse.api.errors import StoreError
 from synapse.logging.context import make_deferred_yieldable
@@ -78,42 +77,39 @@ class AccountValidityHandler(object):
            # run as a background process to make sure that the database transactions
            # have a logcontext to report to
            return run_as_background_process(
-                "send_renewals", self.send_renewal_emails
+                "send_renewals", self._send_renewal_emails
            )

        self.clock.looping_call(send_emails, 30 * 60 * 1000)

-    @defer.inlineCallbacks
-    def send_renewal_emails(self):
+    async def _send_renewal_emails(self):
        """Gets the list of users whose account is expiring in the amount of time
        configured in the ``renew_at`` parameter from the ``account_validity``
        configuration, and sends renewal emails to all of these users as long as they
        have an email 3PID attached to their account.
        """
-        expiring_users = yield self.store.get_users_expiring_soon()
+        expiring_users = await self.store.get_users_expiring_soon()

        if expiring_users:
            for user in expiring_users:
-                yield self._send_renewal_email(
+                await self._send_renewal_email(
                    user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
                )

-    @defer.inlineCallbacks
-    def send_renewal_email_to_user(self, user_id):
-        expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
-        yield self._send_renewal_email(user_id, expiration_ts)
+    async def send_renewal_email_to_user(self, user_id: str):
+        expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
+        await self._send_renewal_email(user_id, expiration_ts)

-    @defer.inlineCallbacks
-    def _send_renewal_email(self, user_id, expiration_ts):
+    async def _send_renewal_email(self, user_id: str, expiration_ts: int):
        """Sends out a renewal email to every email address attached to the given user
        with a unique link allowing them to renew their account.

        Args:
-            user_id (str): ID of the user to send email(s) to.
-            expiration_ts (int): Timestamp in milliseconds for the expiration date of
+            user_id: ID of the user to send email(s) to.
+            expiration_ts: Timestamp in milliseconds for the expiration date of
                this user's account (used in the email templates).
        """
-        addresses = yield self._get_email_addresses_for_user(user_id)
+        addresses = await self._get_email_addresses_for_user(user_id)

        # Stop right here if the user doesn't have at least one email address.
        # In this case, they will have to ask their server admin to renew their
@@ -125,7 +121,7 @@ class AccountValidityHandler(object):
            return

        try:
-            user_display_name = yield self.store.get_profile_displayname(
+            user_display_name = await self.store.get_profile_displayname(
                UserID.from_string(user_id).localpart
            )
            if user_display_name is None:
@@ -133,7 +129,7 @@ class AccountValidityHandler(object):
        except StoreError:
            user_display_name = user_id

-        renewal_token = yield self._get_renewal_token(user_id)
+        renewal_token = await self._get_renewal_token(user_id)
        url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
            self.hs.config.public_baseurl,
            renewal_token,
@@ -165,7 +161,7 @@ class AccountValidityHandler(object):
            logger.info("Sending renewal email to %s", address)

-            yield make_deferred_yieldable(
+            await make_deferred_yieldable(
                self.sendmail(
                    self.hs.config.email_smtp_host,
                    self._raw_from,
@@ -180,19 +176,18 @@ class AccountValidityHandler(object):
                )
            )

-        yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
+        await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)

-    @defer.inlineCallbacks
-    def _get_email_addresses_for_user(self, user_id):
+    async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
        """Retrieve the list of email addresses attached to a user's account.

        Args:
-            user_id (str): ID of the user to lookup email addresses for.
+            user_id: ID of the user to lookup email addresses for.

        Returns:
-            defer.Deferred[list[str]]: Email addresses for this account.
+            Email addresses for this account.
        """
-        threepids = yield self.store.user_get_threepids(user_id)
+        threepids = await self.store.user_get_threepids(user_id)

        addresses = []
        for threepid in threepids:
@@ -201,16 +196,15 @@ class AccountValidityHandler(object):
        return addresses

-    @defer.inlineCallbacks
-    def _get_renewal_token(self, user_id):
+    async def _get_renewal_token(self, user_id: str) -> str:
        """Generates a 32-byte long random string that will be inserted into the
        user's renewal email's unique link, then saves it into the database.

        Args:
-            user_id (str): ID of the user to generate a string for.
+            user_id: ID of the user to generate a string for.

        Returns:
-            defer.Deferred[str]: The generated string.
+            The generated string.

        Raises:
            StoreError(500): Couldn't generate a unique string after 5 attempts.
@@ -219,52 +213,52 @@ class AccountValidityHandler(object):
        while attempts < 5:
            try:
                renewal_token = stringutils.random_string(32)
-                yield self.store.set_renewal_token_for_user(user_id, renewal_token)
+                await self.store.set_renewal_token_for_user(user_id, renewal_token)
                return renewal_token
            except StoreError:
                attempts += 1
        raise StoreError(500, "Couldn't generate a unique string as refresh string.")

-    @defer.inlineCallbacks
-    def renew_account(self, renewal_token):
+    async def renew_account(self, renewal_token: str) -> bool:
        """Renews the account attached to a given renewal token by pushing back the
        expiration date by the current validity period in the server's configuration.

        Args:
-            renewal_token (str): Token sent with the renewal request.
+            renewal_token: Token sent with the renewal request.

        Returns:
-            bool: Whether the provided token is valid.
+            Whether the provided token is valid.
        """
        try:
-            user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+            user_id = await self.store.get_user_from_renewal_token(renewal_token)
        except StoreError:
-            defer.returnValue(False)
+            return False

        logger.debug("Renewing an account for user %s", user_id)
-        yield self.renew_account_for_user(user_id)
+        await self.renew_account_for_user(user_id)

-        defer.returnValue(True)
+        return True

-    @defer.inlineCallbacks
-    def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
+    async def renew_account_for_user(
+        self, user_id: str, expiration_ts: int = None, email_sent: bool = False
+    ) -> int:
        """Renews the account attached to a given user by pushing back the
        expiration date by the current validity period in the server's
        configuration.

        Args:
-            renewal_token (str): Token sent with the renewal request.
-            expiration_ts (int): New expiration date. Defaults to now + validity period.
-            email_sent (bool): Whether an email has been sent for this validity period.
+            renewal_token: Token sent with the renewal request.
+            expiration_ts: New expiration date. Defaults to now + validity period.
+            email_sent: Whether an email has been sent for this validity period.
                Defaults to False.

        Returns:
-            defer.Deferred[int]: New expiration date for this account, as a timestamp
-                in milliseconds since epoch.
+            New expiration date for this account, as a timestamp in
+            milliseconds since epoch.
        """
        if expiration_ts is None:
            expiration_ts = self.clock.time_msec() + self._account_validity.period

-        yield self.store.set_account_validity_for_user(
+        await self.store.set_account_validity_for_user(
            user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
        )
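The `_get_renewal_token` loop above follows a common pattern: generate a random token, attempt an insert that must be unique, and retry on collision up to a bounded number of attempts. A self-contained sketch of that pattern (the in-memory `existing_tokens` set stands in for the database's uniqueness check in the real code):

```python
import secrets

class StoreError(Exception):
    pass

def get_renewal_token(existing_tokens: set, max_attempts: int = 5) -> str:
    """Generate a token unique within `existing_tokens`, retrying on collision."""
    for _ in range(max_attempts):
        token = secrets.token_urlsafe(24)  # URL-safe randomness, ~32 chars
        # The real code relies on a DB unique constraint raising StoreError;
        # here a set membership test plays that role.
        if token not in existing_tokens:
            existing_tokens.add(token)
            return token
    raise StoreError("Couldn't generate a unique string as refresh string.")

tokens = set()
t1 = get_renewal_token(tokens)
t2 = get_renewal_token(tokens)
assert t1 != t2
```

Bounding the retries keeps a pathological collision storm from looping forever, at the cost of a rare 500 error.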

View file

@@ -264,7 +264,6 @@ class E2eKeysHandler(object):

        return ret

-    @defer.inlineCallbacks
    def get_cross_signing_keys_from_cache(self, query, from_user_id):
        """Get cross-signing keys for users from the database

@@ -284,35 +283,14 @@ class E2eKeysHandler(object):
        self_signing_keys = {}
        user_signing_keys = {}

-        for user_id in query:
-            # XXX: consider changing the store functions to allow querying
-            # multiple users simultaneously.
-            key = yield self.store.get_e2e_cross_signing_key(
-                user_id, "master", from_user_id
-            )
-            if key:
-                master_keys[user_id] = key
-
-            key = yield self.store.get_e2e_cross_signing_key(
-                user_id, "self_signing", from_user_id
-            )
-            if key:
-                self_signing_keys[user_id] = key
-
-            # users can see other users' master and self-signing keys, but can
-            # only see their own user-signing keys
-            if from_user_id == user_id:
-                key = yield self.store.get_e2e_cross_signing_key(
-                    user_id, "user_signing", from_user_id
-                )
-                if key:
-                    user_signing_keys[user_id] = key
-
-        return {
-            "master_keys": master_keys,
-            "self_signing_keys": self_signing_keys,
-            "user_signing_keys": user_signing_keys,
-        }
+        # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486
+        return defer.succeed(
+            {
+                "master_keys": master_keys,
+                "self_signing_keys": self_signing_keys,
+                "user_signing_keys": user_signing_keys,
+            }
+        )

    @trace
    @defer.inlineCallbacks

View file

@@ -19,7 +19,7 @@
 import itertools
 import logging
-from typing import Dict, Iterable, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple

 import six
 from six import iteritems, itervalues
@@ -63,8 +63,9 @@ from synapse.replication.http.federation import (
 )
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.types import UserID, get_domain_from_id
-from synapse.util import unwrapFirstError
+from synapse.util import batch_iter, unwrapFirstError
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
@@ -164,8 +165,7 @@ class FederationHandler(BaseHandler):

        self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages

-    @defer.inlineCallbacks
-    def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
+    async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
        """ Process a PDU received via a federation /send/ transaction, or
        via backfill of missing prev_events

@@ -175,17 +175,15 @@ class FederationHandler(BaseHandler):
            pdu (FrozenEvent): received PDU
            sent_to_us_directly (bool): True if this event was pushed to us; False if
                we pulled it as the result of a missing prev_event.
-
-        Returns (Deferred): completes with None
        """

        room_id = pdu.room_id
        event_id = pdu.event_id

-        logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
+        logger.info("handling received PDU: %s", pdu)

        # We reprocess pdus when we have seen them only as outliers
-        existing = yield self.store.get_event(
+        existing = await self.store.get_event(
            event_id, allow_none=True, allow_rejected=True
        )

@@ -229,7 +227,7 @@ class FederationHandler(BaseHandler):
            #
            # Note that if we were never in the room then we would have already
            # dropped the event, since we wouldn't know the room version.
-            is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
+            is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
            if not is_in_room:
                logger.info(
                    "[%s %s] Ignoring PDU from %s as we're not in the room",
@@ -245,12 +243,12 @@ class FederationHandler(BaseHandler):
        # Get missing pdus if necessary.
        if not pdu.internal_metadata.is_outlier():
            # We only backfill backwards to the min depth.
-            min_depth = yield self.get_min_depth_for_context(pdu.room_id)
+            min_depth = await self.get_min_depth_for_context(pdu.room_id)

            logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)

            prevs = set(pdu.prev_event_ids())
-            seen = yield self.store.have_seen_events(prevs)
+            seen = await self.store.have_seen_events(prevs)

            if min_depth and pdu.depth < min_depth:
                # This is so that we don't notify the user about this
@@ -270,7 +268,7 @@ class FederationHandler(BaseHandler):
                    len(missing_prevs),
                    shortstr(missing_prevs),
                )
-                with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+                with (await self._room_pdu_linearizer.queue(pdu.room_id)):
                    logger.info(
                        "[%s %s] Acquired room lock to fetch %d missing prev_events",
                        room_id,
@@ -278,13 +276,19 @@ class FederationHandler(BaseHandler):
                        len(missing_prevs),
                    )

-                    yield self._get_missing_events_for_pdu(
-                        origin, pdu, prevs, min_depth
-                    )
+                    try:
+                        await self._get_missing_events_for_pdu(
+                            origin, pdu, prevs, min_depth
+                        )
+                    except Exception as e:
+                        raise Exception(
+                            "Error fetching missing prev_events for %s: %s"
+                            % (event_id, e)
+                        )

                    # Update the set of things we've seen after trying to
                    # fetch the missing stuff
-                    seen = yield self.store.have_seen_events(prevs)
+                    seen = await self.store.have_seen_events(prevs)

                    if not prevs - seen:
                        logger.info(
@@ -292,14 +296,6 @@ class FederationHandler(BaseHandler):
                            room_id,
                            event_id,
                        )
-            elif missing_prevs:
-                logger.info(
-                    "[%s %s] Not recursively fetching %d missing prev_events: %s",
-                    room_id,
-                    event_id,
-                    len(missing_prevs),
-                    shortstr(missing_prevs),
-                )

            if prevs - seen:
                # We've still not been able to get all of the prev_events for this event.
@@ -344,13 +340,19 @@ class FederationHandler(BaseHandler):
                    affected=pdu.event_id,
                )

+            logger.info(
+                "Event %s is missing prev_events: calculating state for a "
+                "backwards extremity",
+                event_id,
+            )
+
            # Calculate the state after each of the previous events, and
            # resolve them to find the correct state at the current event.
            auth_chains = set()
            event_map = {event_id: pdu}
            try:
                # Get the state of the events we know about
-                ours = yield self.state_store.get_state_groups_ids(room_id, seen)
+                ours = await self.state_store.get_state_groups_ids(room_id, seen)

                # state_maps is a list of mappings from (type, state_key) to event_id
                state_maps = list(
@@ -364,13 +366,10 @@ class FederationHandler(BaseHandler):
                # know about
                for p in prevs - seen:
                    logger.info(
-                        "[%s %s] Requesting state at missing prev_event %s",
-                        room_id,
-                        event_id,
-                        p,
+                        "Requesting state at missing prev_event %s", event_id,
                    )

-                    room_version = yield self.store.get_room_version(room_id)
+                    room_version = await self.store.get_room_version(room_id)

                    with nested_logging_context(p):
                        # note that if any of the missing prevs share missing state or
@@ -379,24 +378,10 @@ class FederationHandler(BaseHandler):
                        (
                            remote_state,
                            got_auth_chain,
-                        ) = yield self.federation_client.get_state_for_room(
-                            origin, room_id, p
+                        ) = await self._get_state_for_room(
+                            origin, room_id, p, include_event_in_state=True
                        )

-                        # we want the state *after* p; get_state_for_room returns the
-                        # state *before* p.
-                        remote_event = yield self.federation_client.get_pdu(
-                            [origin], p, room_version, outlier=True
-                        )
-                        if remote_event is None:
-                            raise Exception(
-                                "Unable to get missing prev_event %s" % (p,)
-                            )
-
-                        if remote_event.is_state():
-                            remote_state.append(remote_event)
-
                        # XXX hrm I'm not convinced that duplicate events will compare
                        # for equality, so I'm not sure this does what the author
                        # hoped.
@@ -410,7 +395,7 @@ class FederationHandler(BaseHandler):
                for x in remote_state:
                    event_map[x.event_id] = x

-                state_map = yield resolve_events_with_store(
+                state_map = await resolve_events_with_store(
                    room_version,
                    state_maps,
                    event_map,
@@ -422,10 +407,10 @@ class FederationHandler(BaseHandler):

                # First though we need to fetch all the events that are in
                # state_map, so we can build up the state below.
-                evs = yield self.store.get_events(
+                evs = await self.store.get_events(
                    list(state_map.values()),
                    get_prev_content=False,
-                    check_redacted=False,
+                    redact_behaviour=EventRedactBehaviour.AS_IS,
                )
                event_map.update(evs)

@@ -446,12 +431,11 @@ class FederationHandler(BaseHandler):
                    affected=event_id,
                )

-        yield self._process_received_pdu(
+        await self._process_received_pdu(
            origin, pdu, state=state, auth_chain=auth_chain
        )
-    @defer.inlineCallbacks
-    def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+    async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
        """
        Args:
            origin (str): Origin of the pdu. Will be called to get the missing events
@@ -463,12 +447,12 @@ class FederationHandler(BaseHandler):
        room_id = pdu.room_id
        event_id = pdu.event_id

-        seen = yield self.store.have_seen_events(prevs)
+        seen = await self.store.have_seen_events(prevs)

        if not prevs - seen:
            return

-        latest = yield self.store.get_latest_event_ids_in_room(room_id)
+        latest = await self.store.get_latest_event_ids_in_room(room_id)

        # We add the prev events that we have seen to the latest
        # list to ensure the remote server doesn't give them to us
@@ -532,7 +516,7 @@ class FederationHandler(BaseHandler):
        # All that said: Let's try increasing the timout to 60s and see what happens.

        try:
-            missing_events = yield self.federation_client.get_missing_events(
+            missing_events = await self.federation_client.get_missing_events(
                origin,
                room_id,
                earliest_events_ids=list(latest),
@@ -571,7 +555,7 @@ class FederationHandler(BaseHandler):
            )
            with nested_logging_context(ev.event_id):
                try:
-                    yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
+                    await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
                except FederationError as e:
                    if e.code == 403:
                        logger.warning(
@@ -583,8 +567,116 @@ class FederationHandler(BaseHandler):
            else:
                raise
-    @defer.inlineCallbacks
-    def _process_received_pdu(self, origin, event, state, auth_chain):
+    async def _get_state_for_room(
+        self,
+        destination: str,
+        room_id: str,
+        event_id: str,
+        include_event_in_state: bool = False,
+    ) -> Tuple[List[EventBase], List[EventBase]]:
+        """Requests all of the room state at a given event from a remote homeserver.
+
+        Args:
+            destination: The remote homeserver to query for the state.
+            room_id: The id of the room we're interested in.
+            event_id: The id of the event we want the state at.
+            include_event_in_state: if true, the event itself will be included in the
+                returned state event list.
+
+        Returns:
+            A list of events in the state, possibly including the event itself, and
+            a list of events in the auth chain for the given event.
+        """
+        (
+            state_event_ids,
+            auth_event_ids,
+        ) = await self.federation_client.get_room_state_ids(
+            destination, room_id, event_id=event_id
+        )
+
+        desired_events = set(state_event_ids + auth_event_ids)
+
+        if include_event_in_state:
+            desired_events.add(event_id)
+
+        event_map = await self._get_events_from_store_or_dest(
+            destination, room_id, desired_events
+        )
+
+        failed_to_fetch = desired_events - event_map.keys()
+        if failed_to_fetch:
+            logger.warning(
+                "Failed to fetch missing state/auth events for %s %s",
+                event_id,
+                failed_to_fetch,
+            )
+
+        remote_state = [
+            event_map[e_id] for e_id in state_event_ids if e_id in event_map
+        ]
+
+        if include_event_in_state:
+            remote_event = event_map.get(event_id)
+            if not remote_event:
+                raise Exception("Unable to get missing prev_event %s" % (event_id,))
+            if remote_event.is_state():
+                remote_state.append(remote_event)
+
+        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+        auth_chain.sort(key=lambda e: e.depth)
+
+        return remote_state, auth_chain
+
+    async def _get_events_from_store_or_dest(
+        self, destination: str, room_id: str, event_ids: Iterable[str]
+    ) -> Dict[str, EventBase]:
+        """Fetch events from a remote destination, checking if we already have them.
+
+        Args:
+            destination
+            room_id
+            event_ids
+
+        Returns:
+            map from event_id to event
+        """
+        fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
+
+        missing_events = set(event_ids) - fetched_events.keys()
+
+        if not missing_events:
+            return fetched_events
+
+        logger.debug(
+            "Fetching unknown state/auth events %s for room %s",
+            missing_events,
+            event_ids,
+        )
+
+        room_version = await self.store.get_room_version(room_id)
+
+        # XXX 20 requests at once? really?
+        for batch in batch_iter(missing_events, 20):
+            deferreds = [
+                run_in_background(
+                    self.federation_client.get_pdu,
+                    destinations=[destination],
+                    event_id=e_id,
+                    room_version=room_version,
+                )
+                for e_id in batch
+            ]
+
+            res = await make_deferred_yieldable(
+                defer.DeferredList(deferreds, consumeErrors=True)
+            )
+            for success, result in res:
+                if success and result:
+                    fetched_events[result.event_id] = result
+
+        return fetched_events
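The fetch loop above relies on `batch_iter` (imported from `synapse.util` in this commit) to cap concurrency: events are requested in chunks of at most 20 in-flight `get_pdu` calls. A minimal reimplementation of that chunking behaviour, as a sketch (the real helper may differ in detail, e.g. the exact container type it yields):

```python
from itertools import islice

def batch_iter(iterable, size):
    """Yield successive tuples of at most `size` items from `iterable`."""
    it = iter(iterable)
    while True:
        batch = tuple(islice(it, size))
        if not batch:
            return
        yield batch

print(list(batch_iter(range(7), 3)))  # [(0, 1, 2), (3, 4, 5), (6,)]
```

Bounding each batch keeps the handler from issuing one request per missing event simultaneously, which the `# XXX 20 requests at once?` comment suggests is itself still considered generous.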
+    async def _process_received_pdu(self, origin, event, state, auth_chain):
        """ Called when we have a new pdu. We need to do auth checks and put it
        through the StateHandler.
        """
@@ -599,7 +691,7 @@ class FederationHandler(BaseHandler):
        if auth_chain:
            event_ids |= {e.event_id for e in auth_chain}

-        seen_ids = yield self.store.have_seen_events(event_ids)
+        seen_ids = await self.store.have_seen_events(event_ids)

        if state and auth_chain is not None:
            # If we have any state or auth_chain given to us by the replication
@@ -626,18 +718,18 @@ class FederationHandler(BaseHandler):
            event_id,
            [e.event.event_id for e in event_infos],
        )
-        yield self._handle_new_events(origin, event_infos)
+        await self._handle_new_events(origin, event_infos)

        try:
-            context = yield self._handle_new_event(origin, event, state=state)
+            context = await self._handle_new_event(origin, event, state=state)
        except AuthError as e:
            raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)

-        room = yield self.store.get_room(room_id)
+        room = await self.store.get_room(room_id)

        if not room:
            try:
-                yield self.store.store_room(
+                await self.store.store_room(
                    room_id=room_id, room_creator_user_id="", is_public=False
                )
            except StoreError:
@@ -650,11 +742,11 @@ class FederationHandler(BaseHandler):
                # changing their profile info.
                newly_joined = True

-                prev_state_ids = yield context.get_prev_state_ids(self.store)
+                prev_state_ids = await context.get_prev_state_ids(self.store)

                prev_state_id = prev_state_ids.get((event.type, event.state_key))
                if prev_state_id:
-                    prev_state = yield self.store.get_event(
+                    prev_state = await self.store.get_event(
                        prev_state_id, allow_none=True
                    )
                    if prev_state and prev_state.membership == Membership.JOIN:
@@ -662,11 +754,10 @@ class FederationHandler(BaseHandler):

            if newly_joined:
                user = UserID.from_string(event.state_key)
-                yield self.user_joined_room(user, room_id)
+                await self.user_joined_room(user, room_id)

    @log_function
-    @defer.inlineCallbacks
-    def backfill(self, dest, room_id, limit, extremities):
+    async def backfill(self, dest, room_id, limit, extremities):
        """ Trigger a backfill request to `dest` for the given `room_id`

        This will attempt to get more events from the remote. If the other side
@@ -683,9 +774,9 @@ class FederationHandler(BaseHandler):
        if dest == self.server_name:
            raise SynapseError(400, "Can't backfill from self.")

-        room_version = yield self.store.get_room_version(room_id)
+        room_version = await self.store.get_room_version(room_id)

-        events = yield self.federation_client.backfill(
+        events = await self.federation_client.backfill(
            dest, room_id, limit=limit, extremities=extremities
        )

@@ -700,7 +791,7 @@ class FederationHandler(BaseHandler):
        #     self._sanity_check_event(ev)

        # Don't bother processing events we already have.
-        seen_events = yield self.store.have_events_in_timeline(
+        seen_events = await self.store.have_events_in_timeline(
            set(e.event_id for e in events)
        )

@@ -723,7 +814,7 @@ class FederationHandler(BaseHandler):
        state_events = {}
        events_to_state = {}
        for e_id in edges:
-            state, auth = yield self.federation_client.get_state_for_room(
+            state, auth = await self._get_state_for_room(
                destination=dest, room_id=room_id, event_id=e_id
            )
            auth_events.update({a.event_id: a for a in auth})
@@ -748,7 +839,7 @@ class FederationHandler(BaseHandler):
        # We repeatedly do this until we stop finding new auth events.
        while missing_auth - failed_to_fetch:
            logger.info("Missing auth for backfill: %r", missing_auth)
-            ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
+            ret_events = await self.store.get_events(missing_auth - failed_to_fetch)
            auth_events.update(ret_events)

            required_auth.update(
@@ -762,7 +853,7 @@ class FederationHandler(BaseHandler):
                missing_auth - failed_to_fetch,
            )

-            results = yield make_deferred_yieldable(
+            results = await make_deferred_yieldable(
                defer.gatherResults(
                    [
                        run_in_background(
@@ -789,7 +880,7 @@ class FederationHandler(BaseHandler):

        failed_to_fetch = missing_auth - set(auth_events)

-        seen_events = yield self.store.have_seen_events(
+        seen_events = await self.store.have_seen_events(
            set(auth_events.keys()) | set(state_events.keys())
        )

@@ -851,7 +942,7 @@ class FederationHandler(BaseHandler):
                )
            )

-        yield self._handle_new_events(dest, ev_infos, backfilled=True)
+        await self._handle_new_events(dest, ev_infos, backfilled=True)

        # Step 2: Persist the rest of the events in the chunk one by one
# Step 2: Persist the rest of the events in the chunk one by one # Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth) events.sort(key=lambda e: e.depth)
@ -867,16 +958,15 @@ class FederationHandler(BaseHandler):
# We store these one at a time since each event depends on the # We store these one at a time since each event depends on the
# previous to work out the state. # previous to work out the state.
# TODO: We can probably do something more clever here. # TODO: We can probably do something more clever here.
yield self._handle_new_event(dest, event, backfilled=True) await self._handle_new_event(dest, event, backfilled=True)
return events return events
@defer.inlineCallbacks async def maybe_backfill(self, room_id, current_depth):
def maybe_backfill(self, room_id, current_depth):
"""Checks the database to see if we should backfill before paginating, """Checks the database to see if we should backfill before paginating,
and if so, does so. and if so, does so.
""" """
extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id) extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities: if not extremities:
logger.debug("Not backfilling as no extremities found.") logger.debug("Not backfilling as no extremities found.")
@ -908,15 +998,17 @@ class FederationHandler(BaseHandler):
# state *before* the event, ignoring the special casing certain event # state *before* the event, ignoring the special casing certain event
# types have. # types have.
forward_events = yield self.store.get_successor_events(list(extremities)) forward_events = await self.store.get_successor_events(list(extremities))
extremities_events = yield self.store.get_events( extremities_events = await self.store.get_events(
forward_events, check_redacted=False, get_prev_content=False forward_events,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
) )
# We set `check_history_visibility_only` as we might otherwise get false # We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased. # positives from users having been erased.
filtered_extremities = yield filter_events_for_server( filtered_extremities = await filter_events_for_server(
self.storage, self.storage,
self.server_name, self.server_name,
list(extremities_events.values()), list(extremities_events.values()),
@ -946,7 +1038,7 @@ class FederationHandler(BaseHandler):
# First we try hosts that are already in the room # First we try hosts that are already in the room
# TODO: HEURISTIC ALERT. # TODO: HEURISTIC ALERT.
curr_state = yield self.state_handler.get_current_state(room_id) curr_state = await self.state_handler.get_current_state(room_id)
def get_domains_from_state(state): def get_domains_from_state(state):
"""Get joined domains from state """Get joined domains from state
@ -985,12 +1077,11 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name domain for domain, depth in curr_domains if domain != self.server_name
] ]
@defer.inlineCallbacks async def try_backfill(domains):
def try_backfill(domains):
# TODO: Should we try multiple of these at a time? # TODO: Should we try multiple of these at a time?
for dom in domains: for dom in domains:
try: try:
yield self.backfill( await self.backfill(
dom, room_id, limit=100, extremities=extremities dom, room_id, limit=100, extremities=extremities
) )
# If this succeeded then we probably already have the # If this succeeded then we probably already have the
@ -1021,7 +1112,7 @@ class FederationHandler(BaseHandler):
return False return False
success = yield try_backfill(likely_domains) success = await try_backfill(likely_domains)
if success: if success:
return True return True
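The `try_backfill` helper above walks the list of likely domains and stops at the first remote that successfully backfills. A minimal asyncio sketch of that "try each candidate until one succeeds" pattern (the names here are illustrative, not Synapse APIs):

```python
import asyncio

async def first_success(attempt, candidates):
    """Try attempt(c) for each candidate in order; return True on the
    first success, False if every candidate fails."""
    for c in candidates:
        try:
            await attempt(c)
            return True
        except Exception:
            # Mirror try_backfill: swallow per-domain failures and move on.
            continue
    return False

async def demo():
    async def attempt(domain):
        if domain != "good.example":
            raise ConnectionError(domain)
    return await first_success(attempt, ["bad.example", "good.example"])

print(asyncio.run(demo()))  # True
```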
@ -1035,7 +1126,7 @@ class FederationHandler(BaseHandler):
logger.debug("calling resolve_state_groups in _maybe_backfill") logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events) resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
states = yield make_deferred_yieldable( states = await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True [resolve(room_id, [e]) for e in event_ids], consumeErrors=True
) )
@ -1045,7 +1136,7 @@ class FederationHandler(BaseHandler):
# event_ids. # event_ids.
states = dict(zip(event_ids, [s.state for s in states])) states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events( state_map = await self.store.get_events(
[e_id for ids in itervalues(states) for e_id in itervalues(ids)], [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
get_prev_content=False, get_prev_content=False,
) )
@ -1061,7 +1152,7 @@ class FederationHandler(BaseHandler):
for e_id, _ in sorted_extremeties_tuple: for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id]) likely_domains = get_domains_from_state(states[e_id])
success = yield try_backfill( success = await try_backfill(
[dom for dom, _ in likely_domains if dom not in tried_domains] [dom for dom, _ in likely_domains if dom not in tried_domains]
) )
if success: if success:
@ -1210,7 +1301,7 @@ class FederationHandler(BaseHandler):
# Check whether this room is the result of an upgrade of a room we already know # Check whether this room is the result of an upgrade of a room we already know
# about. If so, migrate over user information # about. If so, migrate over user information
predecessor = yield self.store.get_room_predecessor(room_id) predecessor = yield self.store.get_room_predecessor(room_id)
if not predecessor: if not predecessor or not isinstance(predecessor.get("room_id"), str):
return return
old_room_id = predecessor["room_id"] old_room_id = predecessor["room_id"]
logger.debug( logger.debug(
@ -1238,8 +1329,7 @@ class FederationHandler(BaseHandler):
return True return True
@defer.inlineCallbacks async def _handle_queued_pdus(self, room_queue):
def _handle_queued_pdus(self, room_queue):
"""Process PDUs which got queued up while we were busy send_joining. """Process PDUs which got queued up while we were busy send_joining.
Args: Args:
@ -1255,7 +1345,7 @@ class FederationHandler(BaseHandler):
p.room_id, p.room_id,
) )
with nested_logging_context(p.event_id): with nested_logging_context(p.event_id):
yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@ -1453,7 +1543,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content): def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
origin, event, event_format_version = yield self._make_and_verify_event( origin, event, event_format_version = yield self._make_and_verify_event(
target_hosts, room_id, user_id, "leave", content=content, target_hosts, room_id, user_id, "leave", content=content
) )
# Mark as outlier as we don't have any state for this event; we're not # Mark as outlier as we don't have any state for this event; we're not
# even in the room. # even in the room.
@ -2814,7 +2904,7 @@ class FederationHandler(BaseHandler):
room_id=room_id, user_id=user.to_string(), change="joined" room_id=room_id, user_id=user.to_string(), change="joined"
) )
else: else:
return user_joined_room(self.distributor, user, room_id) return defer.succeed(user_joined_room(self.distributor, user, room_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_complexity(self, remote_room_hosts, room_id): def get_room_complexity(self, remote_room_hosts, room_id):
@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
@ -875,7 +876,7 @@ class EventCreationHandler(object):
if event.type == EventTypes.Redaction: if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event( original_event = yield self.store.get_event(
event.redacts, event.redacts,
check_redacted=False, redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False, get_prev_content=False,
allow_rejected=False, allow_rejected=False,
allow_none=True, allow_none=True,
@ -952,7 +953,7 @@ class EventCreationHandler(object):
if event.type == EventTypes.Redaction: if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event( original_event = yield self.store.get_event(
event.redacts, event.redacts,
check_redacted=False, redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False, get_prev_content=False,
allow_rejected=False, allow_rejected=False,
allow_none=True, allow_none=True,
@ -280,8 +280,7 @@ class PaginationHandler(object):
await self.storage.purge_events.purge_room(room_id) await self.storage.purge_events.purge_room(room_id)
@defer.inlineCallbacks async def get_messages(
def get_messages(
self, self,
requester, requester,
room_id=None, room_id=None,
@ -307,7 +306,7 @@ class PaginationHandler(object):
room_token = pagin_config.from_token.room_key room_token = pagin_config.from_token.room_key
else: else:
pagin_config.from_token = ( pagin_config.from_token = (
yield self.hs.get_event_sources().get_current_token_for_pagination() await self.hs.get_event_sources().get_current_token_for_pagination()
) )
room_token = pagin_config.from_token.room_key room_token = pagin_config.from_token.room_key
@ -319,11 +318,11 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room") source_config = pagin_config.get_source_config("room")
with (yield self.pagination_lock.read(room_id)): with (await self.pagination_lock.read(room_id)):
( (
membership, membership,
member_event_id, member_event_id,
) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) ) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
if source_config.direction == "b": if source_config.direction == "b":
# if we're going backwards, we might need to backfill. This # if we're going backwards, we might need to backfill. This
@ -331,7 +330,7 @@ class PaginationHandler(object):
if room_token.topological: if room_token.topological:
max_topo = room_token.topological max_topo = room_token.topological
else: else:
max_topo = yield self.store.get_max_topological_token( max_topo = await self.store.get_max_topological_token(
room_id, room_token.stream room_id, room_token.stream
) )
@ -339,18 +338,18 @@ class PaginationHandler(object):
# If they have left the room then clamp the token to be before # If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the # they left the room, to save the effort of loading from the
# database. # database.
leave_token = yield self.store.get_topological_token_for_event( leave_token = await self.store.get_topological_token_for_event(
member_event_id member_event_id
) )
leave_token = RoomStreamToken.parse(leave_token) leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < max_topo: if leave_token.topological < max_topo:
source_config.from_key = str(leave_token) source_config.from_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill( await self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, max_topo room_id, max_topo
) )
events, next_key = yield self.store.paginate_room_events( events, next_key = await self.store.paginate_room_events(
room_id=room_id, room_id=room_id,
from_key=source_config.from_key, from_key=source_config.from_key,
to_key=source_config.to_key, to_key=source_config.to_key,
@ -365,7 +364,7 @@ class PaginationHandler(object):
if event_filter: if event_filter:
events = event_filter.filter(events) events = event_filter.filter(events)
events = yield filter_events_for_client( events = await filter_events_for_client(
self.storage, user_id, events, is_peeking=(member_event_id is None) self.storage, user_id, events, is_peeking=(member_event_id is None)
) )
@ -385,19 +384,19 @@ class PaginationHandler(object):
(EventTypes.Member, event.sender) for event in events (EventTypes.Member, event.sender) for event in events
) )
state_ids = yield self.state_store.get_state_ids_for_event( state_ids = await self.state_store.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter events[0].event_id, state_filter=state_filter
) )
if state_ids: if state_ids:
state = yield self.store.get_events(list(state_ids.values())) state = await self.store.get_events(list(state_ids.values()))
state = state.values() state = state.values()
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
chunk = { chunk = {
"chunk": ( "chunk": (
yield self._event_serializer.serialize_events( await self._event_serializer.serialize_events(
events, time_now, as_client_event=as_client_event events, time_now, as_client_event=as_client_event
) )
), ),
@ -406,7 +405,7 @@ class PaginationHandler(object):
} }
if state: if state:
chunk["state"] = yield self._event_serializer.serialize_events( chunk["state"] = await self._event_serializer.serialize_events(
state, time_now, as_client_event=as_client_event state, time_now, as_client_event=as_client_event
) )
@ -13,20 +13,36 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import re
from typing import Tuple
import attr import attr
import saml2 import saml2
import saml2.response
from saml2.client import Saml2Client from saml2.client import Saml2Client
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
from synapse.rest.client.v1.login import SSOAuthHandler from synapse.rest.client.v1.login import SSOAuthHandler
from synapse.types import UserID, map_username_to_mxid_localpart from synapse.types import (
UserID,
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@attr.s
class Saml2SessionData:
"""Data we track about SAML2 sessions"""
# time the session was created, in milliseconds
creation_time = attr.ib()
class SamlHandler: class SamlHandler:
def __init__(self, hs): def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@ -37,11 +53,14 @@ class SamlHandler:
self._datastore = hs.get_datastore() self._datastore = hs.get_datastore()
self._hostname = hs.hostname self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
self._grandfathered_mxid_source_attribute = ( self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute hs.config.saml2_grandfathered_mxid_source_attribute
) )
self._mxid_mapper = hs.config.saml2_mxid_mapper
# plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
hs.config.saml2_user_mapping_provider_config
)
# identifier for the external_ids table # identifier for the external_ids table
self._auth_provider_id = "saml" self._auth_provider_id = "saml"
@ -118,22 +137,10 @@ class SamlHandler:
remote_user_id = saml2_auth.ava["uid"][0] remote_user_id = saml2_auth.ava["uid"][0]
except KeyError: except KeyError:
logger.warning("SAML2 response lacks a 'uid' attestation") logger.warning("SAML2 response lacks a 'uid' attestation")
raise SynapseError(400, "uid not in SAML2 response") raise SynapseError(400, "'uid' not in SAML2 response")
try:
mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
except KeyError:
logger.warning(
"SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
)
raise SynapseError(
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
)
self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
displayName = saml2_auth.ava.get("displayName", [None])[0]
with (await self._mapping_lock.queue(self._auth_provider_id)): with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user # first of all, check if we already have a mapping for this user
logger.info( logger.info(
@ -173,22 +180,46 @@ class SamlHandler:
) )
return registered_user_id return registered_user_id
# figure out a new mxid for this user # Map saml response to user attributes using the configured mapping provider
base_mxid_localpart = self._mxid_mapper(mxid_source) for i in range(1000):
attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
saml2_auth, i
)
suffix = 0 logger.debug(
while True: "Retrieved SAML attributes from user mapping provider: %s "
localpart = base_mxid_localpart + (str(suffix) if suffix else "") "(attempt %d)",
attribute_dict,
i,
)
localpart = attribute_dict.get("mxid_localpart")
if not localpart:
logger.error(
"SAML mapping provider plugin did not return an "
"mxid_localpart object"
)
raise SynapseError(500, "Error parsing SAML2 response")
displayname = attribute_dict.get("displayname")
# Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive( if not await self._datastore.get_users_by_id_case_insensitive(
UserID(localpart, self._hostname).to_string() UserID(localpart, self._hostname).to_string()
): ):
# This mxid is free
break break
suffix += 1 else:
logger.info("Allocating mxid for new user with localpart %s", localpart) # Unable to generate a username in 1000 iterations
# Break and return error to the user
raise SynapseError(
500, "Unable to generate a Matrix ID from the SAML response"
)
registered_user_id = await self._registration_handler.register_user( registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=displayName localpart=localpart, default_display_name=displayname
) )
await self._datastore.record_user_external_id( await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id self._auth_provider_id, remote_user_id, registered_user_id
) )
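The loop above asks the mapping provider for attributes up to 1000 times, passing the failure count so the provider can append it as a suffix, and raises a 500 if no free localpart turns up. The uniqueness retry can be sketched in isolation (the `taken` set is a hypothetical stand-in for the case-insensitive database lookup):

```python
def allocate_localpart(base, taken, max_attempts=1000):
    """Append the attempt number as a suffix until an unused localpart
    is found; give up after max_attempts, as the handler does."""
    for i in range(max_attempts):
        candidate = base + (str(i) if i else "")
        if candidate not in taken:
            return candidate
    raise RuntimeError("Unable to generate a unique localpart")

print(allocate_localpart("jdoe", {"jdoe", "jdoe1"}))  # jdoe2
```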
@ -205,9 +236,120 @@ class SamlHandler:
del self._outstanding_requests_dict[reqid] del self._outstanding_requests_dict[reqid]
@attr.s DOT_REPLACE_PATTERN = re.compile(
class Saml2SessionData: ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
"""Data we track about SAML2 sessions""" )
# time the session was created, in milliseconds
creation_time = attr.ib() def dot_replace_for_mxid(username: str) -> str:
username = username.lower()
username = DOT_REPLACE_PATTERN.sub(".", username)
# regular mxids aren't allowed to start with an underscore either
username = re.sub("^_", "", username)
return username
MXID_MAPPER_MAP = {
"hexencode": map_username_to_mxid_localpart,
"dotreplace": dot_replace_for_mxid,
}
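`dot_replace_for_mxid` lowercases the input, replaces every character outside the allowed localpart alphabet with a dot, and strips a leading underscore. A standalone sketch of the same transformation (the allowed-character set below is an assumption standing in for `synapse.types.mxid_localpart_allowed_characters`, which is imported rather than shown here):

```python
import re
import string

# Assumption: mirrors synapse.types.mxid_localpart_allowed_characters
ALLOWED = set(string.ascii_lowercase + string.digits + "_-./=")

DOT_REPLACE_PATTERN = re.compile("[^%s]" % (re.escape("".join(sorted(ALLOWED))),))

def dot_replace_for_mxid(username: str) -> str:
    username = username.lower()
    # Replace each disallowed character with a dot
    username = DOT_REPLACE_PATTERN.sub(".", username)
    # Regular mxids aren't allowed to start with an underscore either
    username = re.sub("^_", "", username)
    return username

print(dot_replace_for_mxid("_Jane Doe"))  # jane.doe
```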
@attr.s
class SamlConfig(object):
mxid_source_attribute = attr.ib()
mxid_mapper = attr.ib()
class DefaultSamlMappingProvider(object):
__version__ = "0.0.1"
def __init__(self, parsed_config: SamlConfig):
"""The default SAML user mapping provider
Args:
parsed_config: Module configuration
"""
self._mxid_source_attribute = parsed_config.mxid_source_attribute
self._mxid_mapper = parsed_config.mxid_mapper
def saml_response_to_user_attributes(
self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
) -> dict:
"""Maps some text from a SAML response to attributes of a new user
Args:
saml_response: A SAML auth response object
failures: How many times a call to this function with this
saml_response has resulted in a failure
Returns:
dict: A dict containing new user attributes. Possible keys:
* mxid_localpart (str): Required. The localpart of the user's mxid
* displayname (str): The displayname of the user
"""
try:
mxid_source = saml_response.ava[self._mxid_source_attribute][0]
except KeyError:
logger.warning(
"SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
)
raise SynapseError(
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
)
# Use the configured mapper for this mxid_source
base_mxid_localpart = self._mxid_mapper(mxid_source)
# Append suffix integer if last call to this function failed to produce
# a usable mxid
localpart = base_mxid_localpart + (str(failures) if failures else "")
# Retrieve the display name from the saml response
# If displayname is None, the mxid_localpart will be used instead
displayname = saml_response.ava.get("displayName", [None])[0]
return {
"mxid_localpart": localpart,
"displayname": displayname,
}
@staticmethod
def parse_config(config: dict) -> SamlConfig:
"""Parse the dict provided by the homeserver's config
Args:
config: A dictionary containing configuration options for this provider
Returns:
SamlConfig: A custom config object for this module
"""
# Parse config options and use defaults where necessary
mxid_source_attribute = config.get("mxid_source_attribute", "uid")
mapping_type = config.get("mxid_mapping", "hexencode")
# Retrieve the associated mapping function
try:
mxid_mapper = MXID_MAPPER_MAP[mapping_type]
except KeyError:
raise ConfigError(
"saml2_config.user_mapping_provider.config: '%s' is not a valid "
"mxid_mapping value" % (mapping_type,)
)
return SamlConfig(mxid_source_attribute, mxid_mapper)
@staticmethod
def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
"""Returns the required attributes of a SAML
Args:
config: A SamlConfig object containing configuration params for this provider
Returns:
tuple[set,set]: The first set equates to the saml auth response
attributes that are required for the module to function, whereas the
second set consists of those attributes which can be used if
available, but are not necessary
"""
return {"uid", config.mxid_source_attribute}, {"displayName"}
@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -37,6 +37,7 @@ class SearchHandler(BaseHandler):
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_store = self.storage.state self.state_store = self.storage.state
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id): def get_old_rooms_from_upgraded_room(self, room_id):
@ -53,23 +54,38 @@ class SearchHandler(BaseHandler):
room_id (str): id of the room to search through. room_id (str): id of the room to search through.
Returns: Returns:
Deferred[iterable[unicode]]: predecessor room ids Deferred[iterable[str]]: predecessor room ids
""" """
historical_room_ids = [] historical_room_ids = []
while True: # The initial room must have been known for us to get this far
predecessor = yield self.store.get_room_predecessor(room_id) predecessor = yield self.store.get_room_predecessor(room_id)
# If no predecessor, assume we've hit a dead end while True:
if not predecessor: if not predecessor:
# We have reached the end of the chain of predecessors
break break
# Add predecessor's room ID if not isinstance(predecessor.get("room_id"), str):
historical_room_ids.append(predecessor["room_id"]) # This predecessor object is malformed. Exit here
break
# Scan through the old room for further predecessors predecessor_room_id = predecessor["room_id"]
room_id = predecessor["room_id"]
# Don't add it to the list until we have checked that we are in the room
try:
next_predecessor_room = yield self.store.get_room_predecessor(
predecessor_room_id
)
except NotFoundError:
# The predecessor is not a known room, so we are done here
break
historical_room_ids.append(predecessor_room_id)
# And repeat
predecessor = next_predecessor_room
return historical_room_ids return historical_room_ids
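The rewritten loop validates each predecessor entry and confirms the next room is actually known before recording it, so a malformed or unknown predecessor ends the walk instead of raising. A simplified sketch over an in-memory map (the `predecessors` dict is a stand-in for `store.get_room_predecessor`, mapping each known room id to its predecessor dict or None):

```python
def get_old_rooms(room_id, predecessors):
    """Follow room-upgrade predecessor pointers, newest first."""
    historical_room_ids = []
    predecessor = predecessors.get(room_id)
    while predecessor:
        prev_id = predecessor.get("room_id")
        if not isinstance(prev_id, str):
            break  # malformed predecessor object
        try:
            next_predecessor = predecessors[prev_id]
        except KeyError:
            break  # the predecessor is not a known room
        # Only record the room once we know we have it
        historical_room_ids.append(prev_id)
        predecessor = next_predecessor
    return historical_room_ids

chain = {"!c:s": {"room_id": "!b:s"}, "!b:s": {"room_id": "!a:s"}, "!a:s": None}
print(get_old_rooms("!c:s", chain))  # ['!b:s', '!a:s']
```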
@ -23,6 +23,7 @@ them.
See doc/log_contexts.rst for details on how this works. See doc/log_contexts.rst for details on how this works.
""" """
import inspect
import logging import logging
import threading import threading
import types import types
@ -612,7 +613,8 @@ def run_in_background(f, *args, **kwargs):
def make_deferred_yieldable(deferred): def make_deferred_yieldable(deferred):
"""Given a deferred, make it follow the Synapse logcontext rules: """Given a deferred (or coroutine), make it follow the Synapse logcontext
rules:
If the deferred has completed (or is not actually a Deferred), essentially If the deferred has completed (or is not actually a Deferred), essentially
does nothing (just returns another completed deferred with the does nothing (just returns another completed deferred with the
@ -624,6 +626,13 @@ def make_deferred_yieldable(deferred):
(This is more-or-less the opposite operation to run_in_background.) (This is more-or-less the opposite operation to run_in_background.)
""" """
if inspect.isawaitable(deferred):
# If we're given a coroutine we convert it to a deferred so that we
# run it and find out if it immediately finishes; if it does then we # run it and find out if it immediately finishes; if it does then we
# don't need to fiddle with log contexts at all and can return
# immediately.
deferred = defer.ensureDeferred(deferred)
if not isinstance(deferred, defer.Deferred): if not isinstance(deferred, defer.Deferred):
return deferred return deferred
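The new `inspect.isawaitable` branch means callers can now pass a coroutine straight to `make_deferred_yieldable`: it is wrapped with `defer.ensureDeferred` so an already-finished coroutine takes the fast path with no logcontext juggling. The same shape in stdlib-only form, with asyncio's `ensure_future` playing the role of `defer.ensureDeferred` (an analogy, not Synapse code):

```python
import asyncio
import inspect

def ensure_future_like(obj):
    """If handed a coroutine (or other awaitable), wrap it so it starts
    running and exposes completion state; pass anything else through."""
    if inspect.isawaitable(obj):
        return asyncio.ensure_future(obj)
    return obj

async def demo():
    async def work():
        return "done"
    fut = ensure_future_like(work())
    assert isinstance(fut, asyncio.Future)
    # Non-awaitables are returned unchanged, like the non-Deferred fast path
    assert ensure_future_like(42) == 42
    return await fut

print(asyncio.run(demo()))  # done
```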
@ -20,6 +20,7 @@ import six
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
@ -35,8 +36,8 @@ def __func__(inp):
class BaseSlavedStore(SQLBaseStore): class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs): def __init__(self, database: Database, db_conn, hs):
super(BaseSlavedStore, self).__init__(db_conn, hs) super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker( self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id" db_conn, "cache_invalidation_stream", "stream_id"
@ -18,15 +18,16 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
from synapse.storage.data_stores.main.tags import TagsWorkerStore from synapse.storage.data_stores.main.tags import TagsWorkerStore
from synapse.storage.database import Database
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, database: Database, db_conn, hs):
self._account_data_id_gen = SlavedIdTracker( self._account_data_id_gen = SlavedIdTracker(
db_conn, "account_data_max_stream_id", "stream_id" db_conn, "account_data_max_stream_id", "stream_id"
) )
super(SlavedAccountDataStore, self).__init__(db_conn, hs) super(SlavedAccountDataStore, self).__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self): def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token() return self._account_data_id_gen.get_current_token()
@@ -14,6 +14,7 @@
 # limitations under the License.

 from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
+from synapse.storage.database import Database
 from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.descriptors import Cache

@@ -21,8 +22,8 @@ from ._base import BaseSlavedStore

 class SlavedClientIpStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedClientIpStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedClientIpStore, self).__init__(database, db_conn, hs)

         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR


@@ -16,13 +16,14 @@
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
+from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache


 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
         self._device_inbox_id_gen = SlavedIdTracker(
             db_conn, "device_max_stream_id", "stream_id"
         )


@@ -18,12 +18,13 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
 from synapse.storage.data_stores.main.devices import DeviceWorkerStore
 from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache


 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedDeviceStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedDeviceStore, self).__init__(database, db_conn, hs)

         self.hs = hs
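The hunks above all apply the same refactor: each slaved store's constructor gains an explicit `database: Database` parameter and threads it up the class hierarchy via `super().__init__(...)`. A minimal standalone sketch of that idiom, with illustrative class names rather than Synapse's real ones:

```python
# Sketch of threading an injected dependency through a class hierarchy's
# constructors, as the diff does with `database`. Names are illustrative.

class Database:
    """Stands in for the shared database abstraction being injected."""

    def __init__(self, name: str):
        self.name = name


class BaseStore:
    def __init__(self, database: Database, db_conn, hs):
        # The base class stores the injected Database once, so every
        # subclass in the hierarchy sees the same instance.
        self.database = database
        self.db_conn = db_conn
        self.hs = hs


class WorkerMixin:
    # No __init__ of its own: cooperative super() lookup skips past it
    # to BaseStore in the MRO.
    pass


class SlavedStore(WorkerMixin, BaseStore):
    def __init__(self, database: Database, db_conn, hs):
        # Subclasses accept the new parameter and pass it straight
        # through, rather than reaching for a global.
        super().__init__(database, db_conn, hs)


store = SlavedStore(Database("main"), db_conn=None, hs=None)
print(store.database.name)  # -> main
```

Passing the `Database` explicitly like this keeps construction order visible and makes the dependency testable, which is presumably why the change touches every constructor in lockstep.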

Some files were not shown because too many files have changed in this diff.