Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

This commit is contained in:
Richard van der Hoff 2021-06-16 15:35:00 +01:00
commit 89013b99bd
100 changed files with 2770 additions and 763 deletions

View file

@ -305,11 +305,29 @@ jobs:
with:
path: synapse
- name: Run actions/checkout@v2 for complement
uses: actions/checkout@v2
with:
repository: "matrix-org/complement"
path: complement
# Attempt to check out the same branch of Complement as the PR. If it
# doesn't exist, fallback to master.
- name: Checkout complement
shell: bash
run: |
mkdir -p complement
# Attempt to use the version of complement which best matches the current
# build. Depending on whether this is a PR or release, etc. we need to
# use different fallbacks.
#
# 1. First check if there's a similarly named branch (GITHUB_HEAD_REF
# for pull requests, otherwise GITHUB_REF).
# 2. Attempt to use the base branch, e.g. when merging into release-vX.Y
# (GITHUB_BASE_REF for pull requests).
# 3. Use the default complement branch ("master").
for BRANCH_NAME in "$GITHUB_HEAD_REF" "$GITHUB_BASE_REF" "${GITHUB_REF#refs/heads/}" "master"; do
# Skip empty branch names and merge commits.
if [[ -z "$BRANCH_NAME" || $BRANCH_NAME =~ ^refs/pull/.* ]]; then
continue
fi
(wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break
done
# Build initial Synapse image
- run: docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile .
@ -322,7 +340,7 @@ jobs:
working-directory: complement/dockerfiles
# Run Complement
- run: go test -v -tags synapse_blacklist ./tests
- run: go test -v -tags synapse_blacklist,msc2403,msc2946,msc3083 ./tests
env:
COMPLEMENT_BASE_IMAGE: complement-synapse:latest
working-directory: complement

View file

@ -1,3 +1,9 @@
Synapse 1.36.0 (2021-06-15)
===========================
No significant changes.
Synapse 1.36.0rc2 (2021-06-11)
==============================

View file

@ -173,12 +173,19 @@ source ./env/bin/activate
trial tests.rest.admin.test_room tests.handlers.test_admin.ExfiltrateData.test_invite
```
If your tests fail, you may wish to look at the logs:
If your tests fail, you may wish to look at the logs (the default log level is `ERROR`):
```sh
less _trial_temp/test.log
```
To increase the log level for the tests, set `SYNAPSE_TEST_LOG_LEVEL`:
```sh
SYNAPSE_TEST_LOG_LEVEL=DEBUG trial tests
```
## Run the integration tests.
The integration tests are a more comprehensive suite of tests. They

View file

@ -293,18 +293,6 @@ try installing the failing modules individually::
pip install -e "module-name"
Once this is done, you may wish to run Synapse's unit tests to
check that everything is installed correctly::
python -m twisted.trial tests
This should end with a 'PASSED' result (note that exact numbers will
differ)::
Ran 1337 tests in 716.064s
PASSED (skips=15, successes=1322)
We recommend using the demo which starts 3 federated instances running on ports `8080` - `8082`
./demo/start.sh
@ -324,6 +312,23 @@ If you just want to start a single instance of the app and run it directly::
python -m synapse.app.homeserver --config-path homeserver.yaml
Running the unit tests
======================
After getting up and running, you may wish to run Synapse's unit tests to
check that everything is installed correctly::
trial tests
This should end with a 'PASSED' result (note that exact numbers will
differ)::
Ran 1337 tests in 716.064s
PASSED (skips=15, successes=1322)
For more tips on running the unit tests, like running a specific test or
to see the logging output, see the `CONTRIBUTING doc <CONTRIBUTING.md#run-the-unit-tests>`_.
Running the Integration Tests

1
changelog.d/10080.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to the federation servlets.

1
changelog.d/10115.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug introduced in Synapse v1.25.0 that prevented the `ip_range_whitelist` configuration option from working for federation and identity servers. Contributed by @mikure.

1
changelog.d/10122.doc Normal file
View file

@ -0,0 +1 @@
Mention in the sample homeserver config that you may need to configure max upload size in your reverse proxy. Contributed by @aaronraimist.

1
changelog.d/10134.misc Normal file
View file

@ -0,0 +1 @@
Improve OpenTracing for event persistence.

1
changelog.d/10143.misc Normal file
View file

@ -0,0 +1 @@
Clean up the interface for injecting opentracing over HTTP.

1
changelog.d/10144.misc Normal file
View file

@ -0,0 +1 @@
Limit the number of in-flight `/keys/query` requests from a single device.

1
changelog.d/10145.misc Normal file
View file

@ -0,0 +1 @@
Refactor EventPersistenceQueue.

1
changelog.d/10148.misc Normal file
View file

@ -0,0 +1 @@
Document `SYNAPSE_TEST_LOG_LEVEL` to see the logger output when running tests.

1
changelog.d/10154.bugfix Normal file
View file

@ -0,0 +1 @@
Remove a broken import line in Synapse's admin_cmd worker. Broke in 1.33.0.

1
changelog.d/10155.misc Normal file
View file

@ -0,0 +1 @@
Update the Complement build tags in GitHub Actions to test currently experimental features.

1
changelog.d/10156.misc Normal file
View file

@ -0,0 +1 @@
Add `synapse_federation_soft_failed_events_total` metric to track how often events are soft failed.

1
changelog.d/10157.misc Normal file
View file

@ -0,0 +1 @@
Extend `ResponseCache` to pass a context object into the callback.

1
changelog.d/10160.misc Normal file
View file

@ -0,0 +1 @@
Fetch the corresponding complement branch when performing CI.

View file

@ -0,0 +1 @@
Stop supporting the unstable spaces prefixes from MSC1772.

1
changelog.d/10164.misc Normal file
View file

@ -0,0 +1 @@
Add some developer documentation about boolean columns in database schemas.

View file

@ -0,0 +1 @@
Implement "room knocking" as per [MSC2403](https://github.com/matrix-org/matrix-doc/pull/2403). Contributed by Sorunome and anoa.

1
changelog.d/10175.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a minor bug in the response to `/_matrix/client/r0/user/{user}/openid/request_token`. Contributed by @lukaslihotzki.

1
changelog.d/10180.doc Normal file
View file

@ -0,0 +1 @@
Fix broken links in documentation.

1
changelog.d/10183.misc Normal file
View file

@ -0,0 +1 @@
Add debug logging for when we enter and exit `Measure` blocks.

1
changelog.d/6739.feature Normal file
View file

@ -0,0 +1 @@
Implement "room knocking" as per [MSC2403](https://github.com/matrix-org/matrix-doc/pull/2403). Contributed by Sorunome and anoa.

1
changelog.d/8436.doc Normal file
View file

@ -0,0 +1 @@
Add a new guide to decoding request logs.

1
changelog.d/9359.feature Normal file
View file

@ -0,0 +1 @@
Implement "room knocking" as per [MSC2403](https://github.com/matrix-org/matrix-doc/pull/2403). Contributed by Sorunome and anoa.

1
changelog.d/9933.misc Normal file
View file

@ -0,0 +1 @@
Update the database schema versioning to support gradual migration away from legacy tables.

6
debian/changelog vendored
View file

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.36.0) stable; urgency=medium
* New synapse release 1.36.0.
-- Synapse Packaging team <packages@matrix.org> Tue, 15 Jun 2021 15:41:53 +0100
matrix-synapse-py3 (1.35.1) stable; urgency=medium
* New synapse release 1.35.1.

View file

@ -61,6 +61,7 @@
- [Server Version](admin_api/version_api.md)
- [Manhole](manhole.md)
- [Monitoring](metrics-howto.md)
- [Request log format](usage/administration/request_log.md)
- [Scripts]()
# Development
@ -69,6 +70,7 @@
- [Git Usage](dev/git.md)
- [Testing]()
- [OpenTracing](opentracing.md)
- [Database Schemas](development/database_schema.md)
- [Synapse Architecture]()
- [Log Contexts](log_contexts.md)
- [Replication](replication.md)
@ -84,4 +86,4 @@
- [Scripts]()
# Other
- [Dependency Deprecation Policy](deprecation_policy.md)
- [Dependency Deprecation Policy](deprecation_policy.md)

View file

@ -2,7 +2,7 @@ Admin APIs
==========
**Note**: The latest documentation can be viewed `here <https://matrix-org.github.io/synapse>`_.
See `docs/README.md <../docs/README.md>`_ for more information.
See `docs/README.md <../README.md>`_ for more information.
**Please update links to point to the website instead.** Existing files in this directory
are preserved to maintain historical links, but may be moved in the future.
@ -10,5 +10,5 @@ are preserved to maintain historical links, but may be moved in the future.
This directory includes documentation for the various synapse specific admin
APIs available. Updates to the existing Admin API documentation should still
be made to these files, but any new documentation files should instead be placed under
`docs/usage/administration/admin_api <../docs/usage/administration/admin_api>`_.
`docs/usage/administration/admin_api <../usage/administration/admin_api>`_.

View file

@ -11,4 +11,4 @@ POST /_synapse/admin/v1/delete_group/<group_id>
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: see [Admin API](../../usage/administration/admin_api).
server admin: see [Admin API](../usage/administration/admin_api).

View file

@ -7,7 +7,7 @@ The api is:
GET /_synapse/admin/v1/event_reports?from=0&limit=10
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: see [Admin API](../../usage/administration/admin_api).
server admin: see [Admin API](../usage/administration/admin_api).
It returns a JSON body like the following:
@ -95,7 +95,7 @@ The api is:
GET /_synapse/admin/v1/event_reports/<report_id>
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: see [Admin API](../../usage/administration/admin_api).
server admin: see [Admin API](../usage/administration/admin_api).
It returns a JSON body like the following:

View file

@ -28,7 +28,7 @@ The API is:
GET /_synapse/admin/v1/room/<room_id>/media
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: see [Admin API](../../usage/administration/admin_api).
server admin: see [Admin API](../usage/administration/admin_api).
The API returns a JSON body like the following:
```json
@ -311,7 +311,7 @@ The following fields are returned in the JSON response body:
* `deleted`: integer - The number of media items successfully deleted
To use it, you will need to authenticate by providing an `access_token` for a
server admin: see [Admin API](../../usage/administration/admin_api).
server admin: see [Admin API](../usage/administration/admin_api).
If the user re-requests purged remote media, synapse will re-request the media
from the originating server.

View file

@ -17,7 +17,7 @@ POST /_synapse/admin/v1/purge_history/<room_id>[/<event_id>]
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
By default, events sent by local users are not deleted, as they may represent
the only copies of this content in existence. (Events sent by remote users are

View file

@ -24,7 +24,7 @@ POST /_synapse/admin/v1/join/<room_id_or_alias>
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: see [Admin API](../../usage/administration/admin_api).
server admin: see [Admin API](../usage/administration/admin_api).
Response:

View file

@ -443,7 +443,7 @@ with a body of:
```
To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see [Admin API](../../usage/administration/admin_api).
server admin: see [Admin API](../usage/administration/admin_api).
A response body like the following is returned:

View file

@ -10,7 +10,7 @@ GET /_synapse/admin/v1/statistics/users/media
```
To use it, you will need to authenticate by providing an `access_token`
for a server admin: see [Admin API](../../usage/administration/admin_api).
for a server admin: see [Admin API](../usage/administration/admin_api).
A response body like the following is returned:

View file

@ -11,7 +11,7 @@ GET /_synapse/admin/v2/users/<user_id>
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
It returns a JSON body like the following:
@ -78,7 +78,7 @@ with a body of:
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
URL parameters:
@ -119,7 +119,7 @@ GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
A response body like the following is returned:
@ -237,7 +237,7 @@ See also: [Client Server
API Whois](https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid).
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
It returns a JSON body like the following:
@ -294,7 +294,7 @@ with a body of:
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
The erase parameter is optional and defaults to `false`.
An empty body may be passed for backwards compatibility.
@ -339,7 +339,7 @@ with a body of:
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
The parameter `new_password` is required.
The parameter `logout_devices` is optional and defaults to `true`.
@ -354,7 +354,7 @@ GET /_synapse/admin/v1/users/<user_id>/admin
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
A response body like the following is returned:
@ -384,7 +384,7 @@ with a body of:
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
## List room memberships of a user
@ -398,7 +398,7 @@ GET /_synapse/admin/v1/users/<user_id>/joined_rooms
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
A response body like the following is returned:
@ -443,7 +443,7 @@ GET /_synapse/admin/v1/users/<user_id>/media
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
A response body like the following is returned:
@ -591,7 +591,7 @@ GET /_synapse/admin/v2/users/<user_id>/devices
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
A response body like the following is returned:
@ -659,7 +659,7 @@ POST /_synapse/admin/v2/users/<user_id>/delete_devices
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
An empty JSON dict is returned.
@ -683,7 +683,7 @@ GET /_synapse/admin/v2/users/<user_id>/devices/<device_id>
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
A response body like the following is returned:
@ -731,7 +731,7 @@ PUT /_synapse/admin/v2/users/<user_id>/devices/<device_id>
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
An empty JSON dict is returned.
@ -760,7 +760,7 @@ DELETE /_synapse/admin/v2/users/<user_id>/devices/<device_id>
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
An empty JSON dict is returned.
@ -781,7 +781,7 @@ GET /_synapse/admin/v1/users/<user_id>/pushers
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
A response body like the following is returned:
@ -872,7 +872,7 @@ POST /_synapse/admin/v1/users/<user_id>/shadow_ban
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
An empty JSON dict is returned.
@ -897,7 +897,7 @@ GET /_synapse/admin/v1/users/<user_id>/override_ratelimit
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
A response body like the following is returned:
@ -939,7 +939,7 @@ POST /_synapse/admin/v1/users/<user_id>/override_ratelimit
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
A response body like the following is returned:
@ -984,7 +984,7 @@ DELETE /_synapse/admin/v1/users/<user_id>/override_ratelimit
```
To use it, you will need to authenticate by providing an `access_token` for a
server admin: [Admin API](../../usage/administration/admin_api)
server admin: [Admin API](../usage/administration/admin_api)
An empty JSON dict is returned.

View file

@ -24,8 +24,8 @@ To enable this, first create templates for the policy and success pages.
These should be stored on the local filesystem.
These templates use the [Jinja2](http://jinja.pocoo.org) templating language,
and [docs/privacy_policy_templates](privacy_policy_templates) gives
examples of the sort of thing that can be done.
and [docs/privacy_policy_templates](https://github.com/matrix-org/synapse/tree/develop/docs/privacy_policy_templates/)
gives examples of the sort of thing that can be done.
Note that the templates must be stored under a name giving the language of the
template - currently this must always be `en` (for "English");

View file

@ -0,0 +1,137 @@
# Synapse database schema files
Synapse's database schema is stored in the `synapse.storage.schema` module.
## Logical databases
Synapse supports splitting its datastore across multiple physical databases (which can
be useful for large installations), and the schema files are therefore split according
to the logical database they apply to.
At the time of writing, the following "logical" databases are supported:
* `state` - used to store Matrix room state (more specifically, `state_groups`,
their relationships and contents).
* `main` - stores everything else.
Additionally, the `common` directory contains schema files for tables which must be
present on *all* physical databases.
## Synapse schema versions
Synapse manages its database schema via "schema versions". These are mainly used to
help avoid confusion if the Synapse codebase is rolled back after the database is
updated. They work as follows:
* The Synapse codebase defines a constant `synapse.storage.schema.SCHEMA_VERSION`
which represents the expectations made about the database by that version. For
example, as of Synapse v1.36, this is `59`.
* The database stores a "compatibility version" in
`schema_compat_version.compat_version` which defines the `SCHEMA_VERSION` of the
oldest version of Synapse which will work with the database. On startup, if
`compat_version` is found to be newer than `SCHEMA_VERSION`, Synapse will refuse to
start.
Synapse automatically updates this field from
`synapse.storage.schema.SCHEMA_COMPAT_VERSION`.
* Whenever a backwards-incompatible change is made to the database format (normally
via a `delta` file), `synapse.storage.schema.SCHEMA_COMPAT_VERSION` is also updated
so that administrators can not accidentally roll back to a too-old version of Synapse.
Generally, the goal is to maintain compatibility with at least one or two previous
releases of Synapse, so any substantial change tends to require multiple releases and a
bit of forward-planning to get right.
As a worked example: we want to remove the `room_stats_historical` table. Here is how it
might pan out.
1. Replace any code that *reads* from `room_stats_historical` with alternative
implementations, but keep writing to it in case of rollback to an earlier version.
Also, increase `synapse.storage.schema.SCHEMA_VERSION`. In this
instance, there is no existing code which reads from `room_stats_historical`, so
our starting point is:
v1.36.0: `SCHEMA_VERSION=59`, `SCHEMA_COMPAT_VERSION=59`
2. Next (say in Synapse v1.37.0): remove the code that *writes* to
`room_stats_historical`, but dont yet remove the table in case of rollback to
v1.36.0. Again, we increase `synapse.storage.schema.SCHEMA_VERSION`, but
because we have not broken compatibility with v1.36, we do not yet update
`SCHEMA_COMPAT_VERSION`. We now have:
v1.37.0: `SCHEMA_VERSION=60`, `SCHEMA_COMPAT_VERSION=59`.
3. Later (say in Synapse v1.38.0): we can remove the table altogether. This will
break compatibility with v1.36.0, so we must update `SCHEMA_COMPAT_VERSION` accordingly.
There is no need to update `synapse.storage.schema.SCHEMA_VERSION`, since there is no
change to the Synapse codebase here. So we end up with:
v1.38.0: `SCHEMA_VERSION=60`, `SCHEMA_COMPAT_VERSION=60`.
If in doubt about whether to update `SCHEMA_VERSION` or not, it is generally best to
lean towards doing so.
## Full schema dumps
In the `full_schemas` directories, only the most recently-numbered snapshot is used
(`54` at the time of writing). Older snapshots (eg, `16`) are present for historical
reference only.
### Building full schema dumps
If you want to recreate these schemas, they need to be made from a database that
has had all background updates run.
To do so, use `scripts-dev/make_full_schema.sh`. This will produce new
`full.sql.postgres` and `full.sql.sqlite` files.
Ensure postgres is installed, then run:
./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
NB at the time of writing, this script predates the split into separate `state`/`main`
databases so will require updates to handle that correctly.
## Boolean columns
Boolean columns require special treatment, since SQLite treats booleans the
same as integers.
There are three separate aspects to this:
* Any new boolean column must be added to the `BOOLEAN_COLUMNS` list in
`scripts/synapse_port_db`. This tells the port script to cast the integer
value from SQLite to a boolean before writing the value to the postgres
database.
* Before SQLite 3.23, `TRUE` and `FALSE` were not recognised as constants by
SQLite, and the `IS [NOT] TRUE`/`IS [NOT] FALSE` operators were not
supported. This makes it necessary to avoid using `TRUE` and `FALSE`
constants in SQL commands.
For example, to insert a `TRUE` value into the database, write:
```python
txn.execute("INSERT INTO tbl(col) VALUES (?)", (True, ))
```
* Default values for new boolean columns present a particular
difficulty. Generally it is best to create separate schema files for
Postgres and SQLite. For example:
```sql
# in 00delta.sql.postgres:
ALTER TABLE tbl ADD COLUMN col BOOLEAN DEFAULT FALSE;
```
```sql
# in 00delta.sql.sqlite:
ALTER TABLE tbl ADD COLUMN col BOOLEAN DEFAULT 0;
```
Note that there is a particularly insidious failure mode here: the Postgres
flavour will be accepted by SQLite 3.22, but will give a column whose
default value is the **string** `"FALSE"` - which, when cast back to a boolean
in Python, evaluates to `True`.

View file

@ -14,7 +14,7 @@ you set the `server_name` to match your machine's public DNS hostname.
For this default configuration to work, you will need to listen for TLS
connections on port 8448. The preferred way to do that is by using a
reverse proxy: see [reverse_proxy.md](<reverse_proxy.md>) for instructions
reverse proxy: see [reverse_proxy.md](reverse_proxy.md) for instructions
on how to correctly set one up.
In some cases you might not want to run Synapse on the machine that has
@ -44,7 +44,7 @@ a complicated dance which requires connections in both directions).
Another common problem is that people on other servers can't join rooms that
you invite them to. This can be caused by an incorrectly-configured reverse
proxy: see [reverse_proxy.md](<reverse_proxy.md>) for instructions on how to correctly
proxy: see [reverse_proxy.md](reverse_proxy.md) for instructions on how to correctly
configure a reverse proxy.
### Known issues
@ -63,4 +63,4 @@ release of Synapse.
If you want to get up and running quickly with a trio of homeservers in a
private federation, there is a script in the `demo` directory. This is mainly
useful just for development purposes. See [demo/README](<../demo/README>).
useful just for development purposes. See [demo/README](https://github.com/matrix-org/synapse/tree/develop/demo/).

View file

@ -51,7 +51,7 @@ clients.
Support for this feature can be enabled and configured in the
`retention` section of the Synapse configuration file (see the
[sample file](https://github.com/matrix-org/synapse/blob/v1.7.3/docs/sample_config.yaml#L332-L393)).
[sample file](https://github.com/matrix-org/synapse/blob/v1.36.0/docs/sample_config.yaml#L451-L518)).
To enable support for message retention policies, set the setting
`enabled` in this section to `true`.
@ -87,7 +87,7 @@ expired events from the database. They are only run if support for
message retention policies is enabled in the server's configuration. If
no configuration for purge jobs is configured by the server admin,
Synapse will use a default configuration, which is described in the
[sample configuration file](https://github.com/matrix-org/synapse/blob/master/docs/sample_config.yaml#L332-L393).
[sample configuration file](https://github.com/matrix-org/synapse/blob/v1.36.0/docs/sample_config.yaml#L451-L518).
Some server admins might want a finer control on when events are removed
depending on an event's room's policy. This can be done by setting the

View file

@ -72,8 +72,7 @@
## Monitoring workers
To monitor a Synapse installation using
[workers](https://github.com/matrix-org/synapse/blob/master/docs/workers.md),
To monitor a Synapse installation using [workers](workers.md),
every worker needs to be monitored independently, in addition to
the main homeserver process. This is because workers don't send
their metrics to the main homeserver process, but expose them

View file

@ -30,7 +30,7 @@ presence to (for those users that the receiving user is considered interested in
It does not include state for users who are currently offline, and it can only be
called on workers that support sending federation. Additionally, this method must
only be called from the process that has been configured to write to the
the [presence stream](https://github.com/matrix-org/synapse/blob/master/docs/workers.md#stream-writers).
the [presence stream](workers.md#stream-writers).
By default, this is the main process, but another worker can be configured to do
so.

View file

@ -21,7 +21,7 @@ port 8448. Where these are different, we refer to the 'client port' and the
'federation port'. See [the Matrix
specification](https://matrix.org/docs/spec/server_server/latest#resolving-server-names)
for more details of the algorithm used for federation connections, and
[delegate.md](<delegate.md>) for instructions on setting up delegation.
[delegate.md](delegate.md) for instructions on setting up delegation.
**NOTE**: Your reverse proxy must not `canonicalise` or `normalise`
the requested URI in any way (for example, by decoding `%xx` escapes).

View file

@ -954,6 +954,10 @@ media_store_path: "DATADIR/media_store"
# The largest allowed upload size in bytes
#
# If you are using a reverse proxy you may also need to set this value in
# your reverse proxy's config. Notably Nginx has a small max body size by default.
# See https://matrix-org.github.io/synapse/develop/reverse_proxy.html.
#
#max_upload_size: 50M
# Maximum number of pixels that will be thumbnailed

View file

@ -108,7 +108,7 @@ A custom mapping provider must specify the following methods:
Synapse has a built-in OpenID mapping provider if a custom provider isn't
specified in the config. It is located at
[`synapse.handlers.oidc.JinjaOidcMappingProvider`](../synapse/handlers/oidc.py).
[`synapse.handlers.oidc.JinjaOidcMappingProvider`](https://github.com/matrix-org/synapse/blob/develop/synapse/handlers/oidc.py).
## SAML Mapping Providers
@ -194,4 +194,4 @@ A custom mapping provider must specify the following methods:
Synapse has a built-in SAML mapping provider if a custom provider isn't
specified in the config. It is located at
[`synapse.handlers.saml.DefaultSamlMappingProvider`](../synapse/handlers/saml.py).
[`synapse.handlers.saml.DefaultSamlMappingProvider`](https://github.com/matrix-org/synapse/blob/develop/synapse/handlers/saml.py).

View file

@ -6,16 +6,18 @@ well as a `matrix-synapse-worker@` service template for any workers you
require. Additionally, to group the required services, it sets up a
`matrix-synapse.target`.
See the folder [system](system) for the systemd unit files.
See the folder [system](https://github.com/matrix-org/synapse/tree/develop/docs/systemd-with-workers/system/)
for the systemd unit files.
The folder [workers](workers) contains an example configuration for the
`federation_reader` worker.
The folder [workers](https://github.com/matrix-org/synapse/tree/develop/docs/systemd-with-workers/workers/)
contains an example configuration for the `federation_reader` worker.
## Synapse configuration files
See [workers.md](../workers.md) for information on how to set up the
configuration files and reverse-proxy correctly. You can find an example worker
config in the [workers](workers) folder.
config in the [workers](https://github.com/matrix-org/synapse/tree/develop/docs/systemd-with-workers/workers/)
folder.
Systemd manages daemonization itself, so ensure that none of the configuration
files set either `daemonize` or `worker_daemonize`.
@ -29,8 +31,8 @@ There is no need for a separate configuration file for the master process.
## Set up
1. Adjust synapse configuration files as above.
1. Copy the `*.service` and `*.target` files in [system](system) to
`/etc/systemd/system`.
1. Copy the `*.service` and `*.target` files in [system](https://github.com/matrix-org/synapse/tree/develop/docs/systemd-with-workers/system/)
to `/etc/systemd/system`.
1. Run `systemctl daemon-reload` to tell systemd to load the new unit files.
1. Run `systemctl enable matrix-synapse.service`. This will configure the
synapse master process to be started as part of the `matrix-synapse.target`

View file

@ -0,0 +1,44 @@
# Request log format
HTTP request logs are written by synapse (see [`site.py`](../synapse/http/site.py) for details).
See the following for how to decode the dense data available from the default logging configuration.
```
2020-10-01 12:00:00,000 - synapse.access.http.8008 - 311 - INFO - PUT-1000- 192.168.0.1 - 8008 - {another-matrix-server.com} Processed request: 0.100sec/-0.000sec (0.000sec, 0.000sec) (0.001sec/0.090sec/3) 11B !200 "PUT /_matrix/federation/v1/send/1600000000000 HTTP/1.1" "Synapse/1.20.1" [0 dbevts]
-AAAAAAAAAAAAAAAAAAAAA- -BBBBBBBBBBBBBBBBBBBBBB- -C- -DD- -EEEEEE- -FFFFFFFFF- -GG- -HHHHHHHHHHHHHHHHHHHHHHH- -IIIIII- -JJJJJJJ- -KKKKKK-, -LLLLLL- -MMMMMMM- -NNNNNN- O -P- -QQ- -RRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRR- -SSSSSSSSSSSS- -TTTTTT-
```
| Part | Explanation |
| ----- | ------------ |
| AAAA | Timestamp request was logged (not recieved) |
| BBBB | Logger name (`synapse.access.(http\|https).<tag>`, where 'tag' is defined in the `listeners` config section, normally the port) |
| CCCC | Line number in code |
| DDDD | Log Level |
| EEEE | Request Identifier (This identifier is shared by related log lines)|
| FFFF | Source IP (Or X-Forwarded-For if enabled) |
| GGGG | Server Port |
| HHHH | Federated Server or Local User making request (blank if unauthenticated or not supplied) |
| IIII | Total Time to process the request |
| JJJJ | Time to send response over network once generated (this may be negative if the socket is closed before the response is generated)|
| KKKK | Userland CPU time |
| LLLL | System CPU time |
| MMMM | Total time waiting for a free DB connection from the pool across all parallel DB work from this request |
| NNNN | Total time waiting for response to DB queries across all parallel DB work from this request |
| OOOO | Count of DB transactions performed |
| PPPP | Response body size |
| QQQQ | Response status code (prefixed with ! if the socket was closed before the response was generated) |
| RRRR | Request |
| SSSS | User-agent |
| TTTT | Events fetched from DB to service this request (note that this does not include events fetched from the cache) |
MMMM / NNNN can be greater than IIII if there are multiple slow database queries
running in parallel.
Some actions can result in multiple identical http requests, which will return
the same data, but only the first request will report time/transactions in
`KKKK`/`LLLL`/`MMMM`/`NNNN`/`OOOO` - the others will be awaiting the first query to return a
response and will simultaneously return with the first request, but with very
small processing times.

View file

@ -16,7 +16,7 @@ workers only work with PostgreSQL-based Synapse deployments. SQLite should only
be used for demo purposes and any admin considering workers should already be
running PostgreSQL.
See also https://matrix.org/blog/2020/11/03/how-we-fixed-synapses-scalability
See also [Matrix.org blog post](https://matrix.org/blog/2020/11/03/how-we-fixed-synapses-scalability)
for a higher level overview.
## Main process/worker communication

View file

@ -47,7 +47,7 @@ try:
except ImportError:
pass
__version__ = "1.36.0rc2"
__version__ = "1.36.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when

View file

@ -207,7 +207,7 @@ class Auth:
request.requester = user_id
if user_id in self._force_tracing_for_users:
opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
opentracing.force_tracing()
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("user_id", user_id)
opentracing.set_tag("appservice_id", app_service.id)
@ -260,7 +260,7 @@ class Auth:
request.requester = requester
if user_info.token_owner in self._force_tracing_for_users:
opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
opentracing.force_tracing()
opentracing.set_tag("authenticated_entity", user_info.token_owner)
opentracing.set_tag("user_id", user_info.user_id)
if device_id:

View file

@ -112,8 +112,6 @@ class EventTypes:
SpaceChild = "m.space.child"
SpaceParent = "m.space.parent"
MSC1772_SPACE_CHILD = "org.matrix.msc1772.space.child"
MSC1772_SPACE_PARENT = "org.matrix.msc1772.space.parent"
class ToDeviceEventTypes:
@ -180,7 +178,6 @@ class EventContentFields:
# cf https://github.com/matrix-org/matrix-doc/pull/1772
ROOM_TYPE = "type"
MSC1772_ROOM_TYPE = "org.matrix.msc1772.type"
class RoomEncryptionAlgorithms:

View file

@ -449,7 +449,7 @@ class IncompatibleRoomVersionError(SynapseError):
super().__init__(
code=400,
msg="Your homeserver does not support the features required to "
"join this room",
"interact with this room",
errcode=Codes.INCOMPATIBLE_ROOM_VERSION,
)

View file

@ -56,7 +56,7 @@ class RoomVersion:
state_res = attr.ib(type=int) # one of the StateResolutionVersions
enforce_key_validity = attr.ib(type=bool)
# Before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
# Before MSC2432, m.room.aliases had special auth rules and redaction rules
special_case_aliases_auth = attr.ib(type=bool)
# Strictly enforce canonicaljson, do not allow:
# * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
@ -70,6 +70,9 @@ class RoomVersion:
msc2176_redaction_rules = attr.ib(type=bool)
# MSC3083: Support the 'restricted' join_rule.
msc3083_join_rules = attr.ib(type=bool)
# MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending
# m.room.membership event with membership 'knock'.
msc2403_knocking = attr.ib(type=bool)
class RoomVersions:
@ -84,6 +87,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc2403_knocking=False,
)
V2 = RoomVersion(
"2",
@ -96,6 +100,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc2403_knocking=False,
)
V3 = RoomVersion(
"3",
@ -108,6 +113,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc2403_knocking=False,
)
V4 = RoomVersion(
"4",
@ -120,6 +126,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc2403_knocking=False,
)
V5 = RoomVersion(
"5",
@ -132,6 +139,7 @@ class RoomVersions:
limit_notifications_power_levels=False,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc2403_knocking=False,
)
V6 = RoomVersion(
"6",
@ -144,6 +152,7 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc2403_knocking=False,
)
MSC2176 = RoomVersion(
"org.matrix.msc2176",
@ -156,6 +165,7 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=True,
msc3083_join_rules=False,
msc2403_knocking=False,
)
MSC3083 = RoomVersion(
"org.matrix.msc3083",
@ -168,6 +178,20 @@ class RoomVersions:
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=True,
msc2403_knocking=False,
)
V7 = RoomVersion(
"7",
RoomDisposition.STABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
enforce_key_validity=True,
special_case_aliases_auth=False,
strict_canonicaljson=True,
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc2403_knocking=True,
)
@ -182,5 +206,7 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V6,
RoomVersions.MSC2176,
RoomVersions.MSC3083,
RoomVersions.V7,
)
# Note that we do not include MSC2043 here unless it is enabled in the config.
} # type: Dict[str, RoomVersion]

View file

@ -36,7 +36,6 @@ from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.groups import SlavedGroupServerStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
@ -54,7 +53,6 @@ class AdminCmdSlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedFilteringStore,
SlavedPresenceStore,
SlavedGroupServerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,

View file

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from prometheus_client import Counter
from synapse.api.constants import EventTypes, ThirdPartyEntityKind
from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.events import EventBase
from synapse.events.utils import serialize_event
@ -247,9 +247,14 @@ class ApplicationServiceApi(SimpleHttpClient):
e,
time_now,
as_client_event=True,
is_invite=(
# If this is an invite or a knock membership event, and we're interested
# in this user, then include any stripped state alongside the event.
include_stripped_room_state=(
e.type == EventTypes.Member
and e.membership == "invite"
and (
e.membership == Membership.INVITE
or e.membership == Membership.KNOCK
)
and service.is_interested_in_user(e.state_key)
),
)

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View file

@ -248,6 +248,10 @@ class ContentRepositoryConfig(Config):
# The largest allowed upload size in bytes
#
# If you are using a reverse proxy you may also need to set this value in
# your reverse proxy's config. Notably Nginx has a small max body size by default.
# See https://matrix-org.github.io/synapse/develop/reverse_proxy.html.
#
#max_upload_size: 50M
# Maximum number of pixels that will be thumbnailed

View file

@ -397,19 +397,22 @@ class ServerConfig(Config):
self.ip_range_whitelist = generate_ip_set(
config.get("ip_range_whitelist", ()), config_path=("ip_range_whitelist",)
)
# The federation_ip_range_blacklist is used for backwards-compatibility
# and only applies to federation and identity servers. If it is not given,
# default to ip_range_blacklist.
federation_ip_range_blacklist = config.get(
"federation_ip_range_blacklist", ip_range_blacklist
)
# Always blacklist 0.0.0.0, ::
self.federation_ip_range_blacklist = generate_ip_set(
federation_ip_range_blacklist,
["0.0.0.0", "::"],
config_path=("federation_ip_range_blacklist",),
)
# and only applies to federation and identity servers.
if "federation_ip_range_blacklist" in config:
# Always blacklist 0.0.0.0, ::
self.federation_ip_range_blacklist = generate_ip_set(
config["federation_ip_range_blacklist"],
["0.0.0.0", "::"],
config_path=("federation_ip_range_blacklist",),
)
# 'federation_ip_range_whitelist' was never a supported configuration option.
self.federation_ip_range_whitelist = None
else:
# No backwards-compatiblity requrired, as federation_ip_range_blacklist
# is not given. Default to ip_range_blacklist and ip_range_whitelist.
self.federation_ip_range_blacklist = self.ip_range_blacklist
self.federation_ip_range_whitelist = self.ip_range_whitelist
# (undocumented) option for torturing the worker-mode replication a bit,
# for testing. The value defines the number of milliseconds to pause before

View file

@ -160,6 +160,7 @@ def check(
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])
# 5. If type is m.room.membership
if event.type == EventTypes.Member:
_is_membership_change_allowed(room_version_obj, event, auth_events)
logger.debug("Allowing! %s", event)
@ -257,6 +258,11 @@ def _is_membership_change_allowed(
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
caller_knocked = (
caller
and room_version.msc2403_knocking
and caller.membership == Membership.KNOCK
)
# get info about the target
key = (EventTypes.Member, target_user_id)
@ -283,6 +289,7 @@ def _is_membership_change_allowed(
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
"caller_knocked": caller_knocked,
"target_banned": target_banned,
"target_in_room": target_in_room,
"membership": membership,
@ -299,9 +306,14 @@ def _is_membership_change_allowed(
raise AuthError(403, "%s is banned from the room" % (target_user_id,))
return
if Membership.JOIN != membership:
# Require the user to be in the room for membership changes other than join/knock.
if Membership.JOIN != membership and (
RoomVersion.msc2403_knocking and Membership.KNOCK != membership
):
# If the user has been invited or has knocked, they are allowed to change their
# membership event to leave
if (
caller_invited
(caller_invited or caller_knocked)
and Membership.LEAVE == membership
and target_user_id == event.user_id
):
@ -339,7 +351,9 @@ def _is_membership_change_allowed(
and join_rule == JoinRules.MSC3083_RESTRICTED
):
pass
elif join_rule == JoinRules.INVITE:
elif join_rule == JoinRules.INVITE or (
room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
):
if not caller_in_room and not caller_invited:
raise AuthError(403, "You are not invited to this room.")
else:
@ -358,6 +372,17 @@ def _is_membership_change_allowed(
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
elif room_version.msc2403_knocking and Membership.KNOCK == membership:
if join_rule != JoinRules.KNOCK:
raise AuthError(403, "You don't have permission to knock")
elif target_user_id != event.user_id:
raise AuthError(403, "You cannot knock for other users")
elif target_in_room:
raise AuthError(403, "You cannot knock on a room you are already in")
elif caller_invited:
raise AuthError(403, "You are already invited to this room")
elif target_banned:
raise AuthError(403, "You are banned from this room")
else:
raise AuthError(500, "Unknown membership %s" % membership)
@ -718,7 +743,7 @@ def auth_types_for_event(event: EventBase) -> Set[Tuple[str, str]]:
if event.type == EventTypes.Member:
membership = event.content["membership"]
if membership in [Membership.JOIN, Membership.INVITE]:
if membership in [Membership.JOIN, Membership.INVITE, Membership.KNOCK]:
auth_types.add((EventTypes.JoinRules, ""))
auth_types.add((EventTypes.Member, event.state_key))

View file

@ -242,6 +242,7 @@ def format_event_for_client_v1(d):
"replaces_state",
"prev_content",
"invite_room_state",
"knock_room_state",
)
for key in copy_keys:
if key in d["unsigned"]:
@ -278,7 +279,7 @@ def serialize_event(
event_format=format_event_for_client_v1,
token_id=None,
only_event_fields=None,
is_invite=False,
include_stripped_room_state=False,
):
"""Serialize event for clients
@ -289,8 +290,10 @@ def serialize_event(
event_format
token_id
only_event_fields
is_invite (bool): Whether this is an invite that is being sent to the
invitee
include_stripped_room_state (bool): Some events can have stripped room state
stored in the `unsigned` field. This is required for invite and knock
functionality. If this option is False, that state will be removed from the
event before it is returned. Otherwise, it will be kept.
Returns:
dict
@ -322,11 +325,13 @@ def serialize_event(
if txn_id is not None:
d["unsigned"]["transaction_id"] = txn_id
# If this is an invite for somebody else, then we don't care about the
# invite_room_state as that's meant solely for the invitee. Other clients
# will already have the state since they're in the room.
if not is_invite:
# invite_room_state and knock_room_state are a list of stripped room state events
# that are meant to provide metadata about a room to an invitee/knocker. They are
# intended to only be included in specific circumstances, such as down sync, and
# should not be included in any other case.
if not include_stripped_room_state:
d["unsigned"].pop("invite_room_state", None)
d["unsigned"].pop("knock_room_state", None)
if as_client_event:
d = event_format(d)

View file

@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -619,7 +620,8 @@ class FederationClient(FederationBase):
SynapseError: if the chosen remote server returns a 300/400 code, or
no servers successfully handle the request.
"""
valid_memberships = {Membership.JOIN, Membership.LEAVE}
valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK}
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s"
@ -638,6 +640,13 @@ class FederationClient(FederationBase):
if not room_version:
raise UnsupportedRoomVersionError()
if not room_version.msc2403_knocking and membership == Membership.KNOCK:
raise SynapseError(
400,
"This room version does not support knocking",
errcode=Codes.FORBIDDEN,
)
pdu_dict = ret.get("event", None)
if not isinstance(pdu_dict, dict):
raise InvalidResponseError("Bad 'event' field in response")
@ -946,6 +955,62 @@ class FederationClient(FederationBase):
# content.
return resp[1]
async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict:
"""Attempts to send a knock event to given a list of servers. Iterates
through the list until one attempt succeeds.
Doing so will cause the remote server to add the event to the graph,
and send the event out to the rest of the federation.
Args:
destinations: A list of candidate homeservers which are likely to be
participating in the room.
pdu: The event to be sent.
Returns:
The remote homeserver return some state from the room. The response
dictionary is in the form:
{"knock_state_events": [<state event dict>, ...]}
The list of state events may be empty.
Raises:
SynapseError: If the chosen remote server returns a 3xx/4xx code.
RuntimeError: If no servers were reachable.
"""
async def send_request(destination: str) -> JsonDict:
return await self._do_send_knock(destination, pdu)
return await self._try_destination_list(
"send_knock", destinations, send_request
)
async def _do_send_knock(self, destination: str, pdu: EventBase) -> JsonDict:
"""Send a knock event to a remote homeserver.
Args:
destination: The homeserver to send to.
pdu: The event to send.
Returns:
The remote homeserver can optionally return some state from the room. The response
dictionary is in the form:
{"knock_state_events": [<state event dict>, ...]}
The list of state events may be empty.
"""
time_now = self._clock.time_msec()
return await self.transport_layer.send_knock_v1(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
async def get_public_rooms(
self,
remote_server: str,

View file

@ -129,7 +129,7 @@ class FederationServer(FederationBase):
# come in waves.
self._state_resp_cache = ResponseCache(
hs.get_clock(), "state_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
) # type: ResponseCache[Tuple[str, Optional[str]]]
self._state_ids_resp_cache = ResponseCache(
hs.get_clock(), "state_ids_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
@ -138,6 +138,8 @@ class FederationServer(FederationBase):
hs.config.federation.federation_metrics_domains
)
self._room_prejoin_state_types = hs.config.api.room_prejoin_state
async def on_backfill_request(
self, origin: str, room_id: str, versions: List[str], limit: int
) -> Tuple[int, Dict[str, Any]]:
@ -406,7 +408,7 @@ class FederationServer(FederationBase):
)
async def on_room_state_request(
self, origin: str, room_id: str, event_id: str
self, origin: str, room_id: str, event_id: Optional[str]
) -> Tuple[int, Dict[str, Any]]:
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@ -463,7 +465,7 @@ class FederationServer(FederationBase):
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(
self, room_id: str, event_id: str
self, room_id: str, event_id: Optional[str]
) -> Dict[str, list]:
if event_id:
pdus = await self.handler.get_state_for_pdu(
@ -586,6 +588,103 @@ class FederationServer(FederationBase):
await self.handler.on_send_leave_request(origin, pdu)
return {}
async def on_make_knock_request(
self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
) -> Dict[str, Union[EventBase, str]]:
"""We've received a /make_knock/ request, so we create a partial knock
event for the room and hand that back, along with the room version, to the knocking
homeserver. We do *not* persist or process this event until the other server has
signed it and sent it back.
Args:
origin: The (verified) server name of the requesting server.
room_id: The room to create the knock event in.
user_id: The user to create the knock for.
supported_versions: The room versions supported by the requesting server.
Returns:
The partial knock event.
"""
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
room_version = await self.store.get_room_version(room_id)
# Check that this room version is supported by the remote homeserver
if room_version.identifier not in supported_versions:
logger.warning(
"Room version %s not in %s", room_version.identifier, supported_versions
)
raise IncompatibleRoomVersionError(room_version=room_version.identifier)
# Check that this room supports knocking as defined by its room version
if not room_version.msc2403_knocking:
raise SynapseError(
403,
"This room version does not support knocking",
errcode=Codes.FORBIDDEN,
)
pdu = await self.handler.on_make_knock_request(origin, room_id, user_id)
time_now = self._clock.time_msec()
return {
"event": pdu.get_pdu_json(time_now),
"room_version": room_version.identifier,
}
async def on_send_knock_request(
self,
origin: str,
content: JsonDict,
room_id: str,
) -> Dict[str, List[JsonDict]]:
"""
We have received a knock event for a room. Verify and send the event into the room
on the knocking homeserver's behalf. Then reply with some stripped state from the
room for the knockee.
Args:
origin: The remote homeserver of the knocking user.
content: The content of the request.
room_id: The ID of the room to knock on.
Returns:
The stripped room state.
"""
logger.debug("on_send_knock_request: content: %s", content)
room_version = await self.store.get_room_version(room_id)
# Check that this room supports knocking as defined by its room version
if not room_version.msc2403_knocking:
raise SynapseError(
403,
"This room version does not support knocking",
errcode=Codes.FORBIDDEN,
)
pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, pdu.room_id)
logger.debug("on_send_knock_request: pdu sigs: %s", pdu.signatures)
pdu = await self._check_sigs_and_hash(room_version, pdu)
# Handle the event, and retrieve the EventContext
event_context = await self.handler.on_send_knock_request(origin, pdu)
# Retrieve stripped state events from the room and send them back to the remote
# server. This will allow the remote server's clients to display information
# related to the room while the knock request is pending.
stripped_room_state = (
await self.store.get_stripped_room_state_from_event_context(
event_context, self._room_prejoin_state_types
)
)
return {"knock_state_events": stripped_room_state}
async def on_event_auth(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]:

View file

@ -1,5 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -220,7 +220,8 @@ class TransportLayerClient:
Fails with ``FederationDeniedError`` if the remote destination
is not in our federation whitelist
"""
valid_memberships = {Membership.JOIN, Membership.LEAVE}
valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK}
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s"
@ -321,6 +322,40 @@ class TransportLayerClient:
return response
@log_function
async def send_knock_v1(
self,
destination: str,
room_id: str,
event_id: str,
content: JsonDict,
) -> JsonDict:
"""
Sends a signed knock membership event to a remote server. This is the second
step for knocking after make_knock.
Args:
destination: The remote homeserver.
room_id: The ID of the room to knock on.
event_id: The ID of the knock membership event that we're sending.
content: The knock membership event that we're sending. Note that this is not the
`content` field of the membership event, but the entire signed membership event
itself represented as a JSON dict.
Returns:
The remote homeserver can optionally return some state from the room. The response
dictionary is in the form:
{"knock_state_events": [<state event dict>, ...]}
The list of state events may be empty.
"""
path = _create_v1_path("/send_knock/%s/%s", room_id, event_id)
return await self.client.put_json(
destination=destination, path=path, data=content
)
@log_function
async def send_invite_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/invite/%s/%s", room_id, event_id)

View file

@ -1,6 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
import re
@ -28,12 +26,14 @@ from synapse.api.urls import (
FEDERATION_V1_PREFIX,
FEDERATION_V2_PREFIX,
)
from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import (
parse_boolean_from_args,
parse_integer_from_args,
parse_json_object_from_request,
parse_string_from_args,
parse_strings_from_args,
)
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
@ -275,10 +275,17 @@ class BaseFederationServlet:
RATELIMIT = True # Whether to rate limit requests or not
def __init__(self, handler, authenticator, ratelimiter, server_name):
self.handler = handler
def __init__(
self,
hs: HomeServer,
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
):
self.hs = hs
self.authenticator = authenticator
self.ratelimiter = ratelimiter
self.server_name = server_name
def _wrap(self, func):
authenticator = self.authenticator
@ -375,17 +382,30 @@ class BaseFederationServlet:
)
class FederationSendServlet(BaseFederationServlet):
class BaseFederationServerServlet(BaseFederationServlet):
"""Abstract base class for federation servlet classes which provides a federation server handler.
See BaseFederationServlet for more information.
"""
def __init__(
self,
hs: HomeServer,
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self.handler = hs.get_federation_server()
class FederationSendServlet(BaseFederationServerServlet):
PATH = "/send/(?P<transaction_id>[^/]*)/?"
# We ratelimit manually in the handler as we queue up the requests and we
# don't want to fill up the ratelimiter with blocked requests.
RATELIMIT = False
def __init__(self, handler, server_name, **kwargs):
super().__init__(handler, server_name=server_name, **kwargs)
self.server_name = server_name
# This is when someone is trying to send us a bunch of data.
async def on_PUT(self, origin, content, query, transaction_id):
"""Called on PUT /send/<transaction_id>/
@ -434,7 +454,7 @@ class FederationSendServlet(BaseFederationServlet):
return code, response
class FederationEventServlet(BaseFederationServlet):
class FederationEventServlet(BaseFederationServerServlet):
PATH = "/event/(?P<event_id>[^/]*)/?"
# This is when someone asks for a data item for a given server data_id pair.
@ -442,7 +462,7 @@ class FederationEventServlet(BaseFederationServlet):
return await self.handler.on_pdu_request(origin, event_id)
class FederationStateV1Servlet(BaseFederationServlet):
class FederationStateV1Servlet(BaseFederationServerServlet):
PATH = "/state/(?P<room_id>[^/]*)/?"
# This is when someone asks for all data for a given room.
@ -454,7 +474,7 @@ class FederationStateV1Servlet(BaseFederationServlet):
)
class FederationStateIdsServlet(BaseFederationServlet):
class FederationStateIdsServlet(BaseFederationServerServlet):
PATH = "/state_ids/(?P<room_id>[^/]*)/?"
async def on_GET(self, origin, content, query, room_id):
@ -465,7 +485,7 @@ class FederationStateIdsServlet(BaseFederationServlet):
)
class FederationBackfillServlet(BaseFederationServlet):
class FederationBackfillServlet(BaseFederationServerServlet):
PATH = "/backfill/(?P<room_id>[^/]*)/?"
async def on_GET(self, origin, content, query, room_id):
@ -478,7 +498,7 @@ class FederationBackfillServlet(BaseFederationServlet):
return await self.handler.on_backfill_request(origin, room_id, versions, limit)
class FederationQueryServlet(BaseFederationServlet):
class FederationQueryServlet(BaseFederationServerServlet):
PATH = "/query/(?P<query_type>[^/]*)"
# This is when we receive a server-server Query
@ -488,7 +508,7 @@ class FederationQueryServlet(BaseFederationServlet):
return await self.handler.on_query_request(query_type, args)
class FederationMakeJoinServlet(BaseFederationServlet):
class FederationMakeJoinServlet(BaseFederationServerServlet):
PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
async def on_GET(self, origin, _content, query, room_id, user_id):
@ -518,7 +538,7 @@ class FederationMakeJoinServlet(BaseFederationServlet):
return 200, content
class FederationMakeLeaveServlet(BaseFederationServlet):
class FederationMakeLeaveServlet(BaseFederationServerServlet):
PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
async def on_GET(self, origin, content, query, room_id, user_id):
@ -526,7 +546,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
return 200, content
class FederationV1SendLeaveServlet(BaseFederationServlet):
class FederationV1SendLeaveServlet(BaseFederationServerServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
@ -534,7 +554,7 @@ class FederationV1SendLeaveServlet(BaseFederationServlet):
return 200, (200, content)
class FederationV2SendLeaveServlet(BaseFederationServlet):
class FederationV2SendLeaveServlet(BaseFederationServerServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
@ -544,14 +564,38 @@ class FederationV2SendLeaveServlet(BaseFederationServlet):
return 200, content
class FederationEventAuthServlet(BaseFederationServlet):
class FederationMakeKnockServlet(BaseFederationServerServlet):
PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
async def on_GET(self, origin, content, query, room_id, user_id):
try:
# Retrieve the room versions the remote homeserver claims to support
supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8")
except KeyError:
raise SynapseError(400, "Missing required query parameter 'ver'")
content = await self.handler.on_make_knock_request(
origin, room_id, user_id, supported_versions=supported_versions
)
return 200, content
class FederationV1SendKnockServlet(BaseFederationServerServlet):
PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
content = await self.handler.on_send_knock_request(origin, content, room_id)
return 200, content
class FederationEventAuthServlet(BaseFederationServerServlet):
PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_GET(self, origin, content, query, room_id, event_id):
return await self.handler.on_event_auth(origin, room_id, event_id)
class FederationV1SendJoinServlet(BaseFederationServlet):
class FederationV1SendJoinServlet(BaseFederationServerServlet):
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
@ -561,7 +605,7 @@ class FederationV1SendJoinServlet(BaseFederationServlet):
return 200, (200, content)
class FederationV2SendJoinServlet(BaseFederationServlet):
class FederationV2SendJoinServlet(BaseFederationServerServlet):
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
@ -573,7 +617,7 @@ class FederationV2SendJoinServlet(BaseFederationServlet):
return 200, content
class FederationV1InviteServlet(BaseFederationServlet):
class FederationV1InviteServlet(BaseFederationServerServlet):
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
@ -590,7 +634,7 @@ class FederationV1InviteServlet(BaseFederationServlet):
return 200, (200, content)
class FederationV2InviteServlet(BaseFederationServlet):
class FederationV2InviteServlet(BaseFederationServerServlet):
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
@ -614,7 +658,7 @@ class FederationV2InviteServlet(BaseFederationServlet):
return 200, content
class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id):
@ -622,21 +666,21 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
return 200, {}
class FederationClientKeysQueryServlet(BaseFederationServlet):
class FederationClientKeysQueryServlet(BaseFederationServerServlet):
PATH = "/user/keys/query"
async def on_POST(self, origin, content, query):
return await self.handler.on_query_client_keys(origin, content)
class FederationUserDevicesQueryServlet(BaseFederationServlet):
class FederationUserDevicesQueryServlet(BaseFederationServerServlet):
PATH = "/user/devices/(?P<user_id>[^/]*)"
async def on_GET(self, origin, content, query, user_id):
return await self.handler.on_query_user_devices(origin, user_id)
class FederationClientKeysClaimServlet(BaseFederationServlet):
class FederationClientKeysClaimServlet(BaseFederationServerServlet):
PATH = "/user/keys/claim"
async def on_POST(self, origin, content, query):
@ -644,7 +688,7 @@ class FederationClientKeysClaimServlet(BaseFederationServlet):
return 200, response
class FederationGetMissingEventsServlet(BaseFederationServlet):
class FederationGetMissingEventsServlet(BaseFederationServerServlet):
# TODO(paul): Why does this path alone end with "/?" optional?
PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
@ -664,7 +708,7 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
return 200, content
class On3pidBindServlet(BaseFederationServlet):
class On3pidBindServlet(BaseFederationServerServlet):
PATH = "/3pid/onbind"
REQUIRE_AUTH = False
@ -694,7 +738,7 @@ class On3pidBindServlet(BaseFederationServlet):
return 200, {}
class OpenIdUserInfo(BaseFederationServlet):
class OpenIdUserInfo(BaseFederationServerServlet):
"""
Exchange a bearer token for information about a user.
@ -770,8 +814,16 @@ class PublicRoomList(BaseFederationServlet):
PATH = "/publicRooms"
def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access):
super().__init__(handler, authenticator, ratelimiter, server_name)
def __init__(
self,
hs: HomeServer,
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
allow_access: bool,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self.handler = hs.get_room_list_handler()
self.allow_access = allow_access
async def on_GET(self, origin, content, query):
@ -856,7 +908,24 @@ class FederationVersionServlet(BaseFederationServlet):
)
class FederationGroupsProfileServlet(BaseFederationServlet):
class BaseGroupsServerServlet(BaseFederationServlet):
"""Abstract base class for federation servlet classes which provides a groups server handler.
See BaseFederationServlet for more information.
"""
def __init__(
self,
hs: HomeServer,
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self.handler = hs.get_groups_server_handler()
class FederationGroupsProfileServlet(BaseGroupsServerServlet):
"""Get/set the basic profile of a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/profile"
@ -882,7 +951,7 @@ class FederationGroupsProfileServlet(BaseFederationServlet):
return 200, new_content
class FederationGroupsSummaryServlet(BaseFederationServlet):
class FederationGroupsSummaryServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/summary"
async def on_GET(self, origin, content, query, group_id):
@ -895,7 +964,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
return 200, new_content
class FederationGroupsRoomsServlet(BaseFederationServlet):
class FederationGroupsRoomsServlet(BaseGroupsServerServlet):
"""Get the rooms in a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/rooms"
@ -910,7 +979,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
return 200, new_content
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
"""Add/remove room from group"""
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
@ -938,7 +1007,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
return 200, new_content
class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet):
"""Update room config in group"""
PATH = (
@ -958,7 +1027,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
return 200, result
class FederationGroupsUsersServlet(BaseFederationServlet):
class FederationGroupsUsersServlet(BaseGroupsServerServlet):
"""Get the users in a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/users"
@ -973,7 +1042,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
return 200, new_content
class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet):
"""Get the users that have been invited to a group"""
PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
@ -990,7 +1059,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
return 200, new_content
class FederationGroupsInviteServlet(BaseFederationServlet):
class FederationGroupsInviteServlet(BaseGroupsServerServlet):
"""Ask a group server to invite someone to the group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@ -1007,7 +1076,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
return 200, new_content
class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet):
"""Accept an invitation from the group server"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
@ -1021,7 +1090,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
return 200, new_content
class FederationGroupsJoinServlet(BaseFederationServlet):
class FederationGroupsJoinServlet(BaseGroupsServerServlet):
"""Attempt to join a group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
@ -1035,7 +1104,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
return 200, new_content
class FederationGroupsRemoveUserServlet(BaseFederationServlet):
class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet):
"""Leave or kick a user from the group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@ -1052,7 +1121,24 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
return 200, new_content
class FederationGroupsLocalInviteServlet(BaseFederationServlet):
class BaseGroupsLocalServlet(BaseFederationServlet):
"""Abstract base class for federation servlet classes which provides a groups local handler.
See BaseFederationServlet for more information.
"""
def __init__(
self,
hs: HomeServer,
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self.handler = hs.get_groups_local_handler()
class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet):
"""A group server has invited a local user"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@ -1061,12 +1147,16 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "group_id doesn't match origin")
assert isinstance(
self.handler, GroupsLocalHandler
), "Workers cannot handle group invites."
new_content = await self.handler.on_invite(group_id, user_id, content)
return 200, new_content
class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet):
"""A group server has removed a local user"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@ -1075,6 +1165,10 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
assert isinstance(
self.handler, GroupsLocalHandler
), "Workers cannot handle group removals."
new_content = await self.handler.user_removed_from_group(
group_id, user_id, content
)
@ -1087,6 +1181,16 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
def __init__(
self,
hs: HomeServer,
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self.handler = hs.get_groups_attestation_renewer()
async def on_POST(self, origin, content, query, group_id, user_id):
# We don't need to check auth here as we check the attestation signatures
@ -1097,7 +1201,7 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
return 200, new_content
class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet):
"""Add/remove a room from the group summary, with optional category.
Matches both:
@ -1154,7 +1258,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
return 200, resp
class FederationGroupsCategoriesServlet(BaseFederationServlet):
class FederationGroupsCategoriesServlet(BaseGroupsServerServlet):
"""Get all categories for a group"""
PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
@ -1169,7 +1273,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
return 200, resp
class FederationGroupsCategoryServlet(BaseFederationServlet):
class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
"""Add/remove/get a category in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
@ -1222,7 +1326,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
return 200, resp
class FederationGroupsRolesServlet(BaseFederationServlet):
class FederationGroupsRolesServlet(BaseGroupsServerServlet):
"""Get roles in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
@ -1237,7 +1341,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
return 200, resp
class FederationGroupsRoleServlet(BaseFederationServlet):
class FederationGroupsRoleServlet(BaseGroupsServerServlet):
"""Add/remove/get a role in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
@ -1290,7 +1394,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
return 200, resp
class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
"""Add/remove a user from the group summary, with optional role.
Matches both:
@ -1345,7 +1449,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
return 200, resp
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet):
"""Get roles in a group"""
PATH = "/get_groups_publicised"
@ -1358,7 +1462,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
return 200, resp
class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet):
"""Sets whether a group is joinable without an invite or knock"""
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
@ -1379,6 +1483,16 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
PATH = "/spaces/(?P<room_id>[^/]*)"
def __init__(
self,
hs: HomeServer,
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self.handler = hs.get_space_summary_handler()
async def on_GET(
self,
origin: str,
@ -1444,16 +1558,25 @@ class RoomComplexityServlet(BaseFederationServlet):
PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
PREFIX = FEDERATION_UNSTABLE_PREFIX
def __init__(
self,
hs: HomeServer,
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self._store = self.hs.get_datastore()
async def on_GET(self, origin, content, query, room_id):
store = self.handler.hs.get_datastore()
is_public = await store.is_room_world_readable_or_publicly_joinable(room_id)
is_public = await self._store.is_room_world_readable_or_publicly_joinable(
room_id
)
if not is_public:
raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
complexity = await store.get_room_complexity(room_id)
complexity = await self._store.get_room_complexity(room_id)
return 200, complexity
@ -1482,6 +1605,9 @@ FEDERATION_SERVLET_CLASSES = (
On3pidBindServlet,
FederationVersionServlet,
RoomComplexityServlet,
FederationSpaceSummaryServlet,
FederationV1SendKnockServlet,
FederationMakeKnockServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]
OPENID_SERVLET_CLASSES = (
@ -1523,6 +1649,7 @@ GROUP_ATTESTATION_SERVLET_CLASSES = (
FederationGroupsRenewAttestaionServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]
DEFAULT_SERVLET_GROUPS = (
"federation",
"room_list",
@ -1559,23 +1686,16 @@ def register_servlets(
if "federation" in servlet_groups:
for servletclass in FEDERATION_SERVLET_CLASSES:
servletclass(
handler=hs.get_federation_server(),
hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
FederationSpaceSummaryServlet(
handler=hs.get_space_summary_handler(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
if "openid" in servlet_groups:
for servletclass in OPENID_SERVLET_CLASSES:
servletclass(
handler=hs.get_federation_server(),
hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@ -1584,7 +1704,7 @@ def register_servlets(
if "room_list" in servlet_groups:
for servletclass in ROOM_LIST_CLASSES:
servletclass(
handler=hs.get_room_list_handler(),
hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@ -1594,7 +1714,7 @@ def register_servlets(
if "group_server" in servlet_groups:
for servletclass in GROUP_SERVER_SERVLET_CLASSES:
servletclass(
handler=hs.get_groups_server_handler(),
hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@ -1603,7 +1723,7 @@ def register_servlets(
if "group_local" in servlet_groups:
for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
servletclass(
handler=hs.get_groups_local_handler(),
hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@ -1612,7 +1732,7 @@ def register_servlets(
if "group_attestation" in servlet_groups:
for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
servletclass(
handler=hs.get_groups_attestation_renewer(),
hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,

View file

@ -79,9 +79,15 @@ class E2eKeysHandler:
"client_keys", self.on_federation_query_client_keys
)
# Limit the number of in-flight requests from a single device.
self._query_devices_linearizer = Linearizer(
name="query_devices",
max_count=10,
)
@trace
async def query_devices(
self, query_body: JsonDict, timeout: int, from_user_id: str
self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
) -> JsonDict:
"""Handle a device key query from a client
@ -105,191 +111,197 @@ class E2eKeysHandler:
from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
from_device_id: the device making the query. This is used to limit
the number of in-flight queries at a time.
"""
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query = query_body.get(
"device_keys", {}
) # type: Dict[str, Iterable[str]]
device_keys_query = query_body.get(
"device_keys", {}
) # type: Dict[str, Iterable[str]]
# separate users by domain.
# make a map from domain to user_id to device_ids
local_query = {}
remote_queries = {}
# separate users by domain.
# make a map from domain to user_id to device_ids
local_query = {}
remote_queries = {}
for user_id, device_ids in device_keys_query.items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
else:
remote_queries[user_id] = device_ids
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
# First get local devices.
# A map of destination -> failure response.
failures = {} # type: Dict[str, JsonDict]
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
for user_id, keys in local_result.items():
if user_id in local_query:
results[user_id] = keys
# Get cached cross-signing keys
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, from_user_id
)
# Now attempt to get any remote devices from our local cache.
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
if remote_queries:
query_list = [] # type: List[Tuple[str, Optional[str]]]
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
for user_id, device_ids in device_keys_query.items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
else:
query_list.append((user_id, None))
remote_queries[user_id] = device_ids
(
user_ids_not_in_cache,
remote_results,
) = await self.store.get_user_devices_from_cache(query_list)
for user_id, devices in remote_results.items():
user_devices = results.setdefault(user_id, {})
for device_id, device in devices.items():
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
result = dict(keys)
unsigned = result.setdefault("unsigned", {})
if device_display_name:
unsigned["device_display_name"] = device_display_name
user_devices[device_id] = result
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
# check for missing cross-signing keys.
for user_id in remote_queries.keys():
cached_cross_master = user_id in cross_signing_keys["master_keys"]
cached_cross_selfsigning = (
user_id in cross_signing_keys["self_signing_keys"]
)
# check if we are missing only one of cross-signing master or
# self-signing key, but the other one is cached.
# as we need both, this will issue a federation request.
# if we don't have any of the keys, either the user doesn't have
# cross-signing set up, or the cached device list
# is not (yet) updated.
if cached_cross_master ^ cached_cross_selfsigning:
user_ids_not_in_cache.add(user_id)
# add those users to the list to fetch over federation.
for user_id in user_ids_not_in_cache:
domain = get_domain_from_id(user_id)
r = remote_queries_not_in_cache.setdefault(domain, {})
r[user_id] = remote_queries[user_id]
# Now fetch any devices that we don't have in our cache
@trace
async def do_remote_query(destination):
"""This is called when we are querying the device list of a user on
a remote homeserver and their device list is not in the device list
cache. If we share a room with this user and we're not querying for
specific user we will update the cache with their device list.
"""
destination_query = remote_queries_not_in_cache[destination]
# We first consider whether we wish to update the device list cache with
# the users device list. We want to track a user's devices when the
# authenticated user shares a room with the queried user and the query
# has not specified a particular device.
# If we update the cache for the queried user we remove them from further
# queries. We use the more efficient batched query_client_keys for all
# remaining users
user_ids_updated = []
for (user_id, device_list) in destination_query.items():
if user_id in user_ids_updated:
continue
if device_list:
continue
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
continue
# We've decided we're sharing a room with this user and should
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
if self._is_master:
user_devices = await self.device_handler.device_list_updater.user_device_resync(
user_id
)
else:
user_devices = await self._user_device_resync_client(
user_id=user_id
)
user_devices = user_devices["devices"]
user_results = results.setdefault(user_id, {})
for device in user_devices:
user_results[device["device_id"]] = device["keys"]
user_ids_updated.append(user_id)
except Exception as e:
failures[destination] = _exception_to_failure(e)
if len(destination_query) == len(user_ids_updated):
# We've updated all the users in the query and we do not need to
# make any further remote calls.
return
# Remove all the users from the query which we have updated
for user_id in user_ids_updated:
destination_query.pop(user_id)
try:
remote_result = await self.federation.query_client_keys(
destination, {"device_keys": destination_query}, timeout=timeout
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query:
# First get local devices.
# A map of destination -> failure response.
failures = {} # type: Dict[str, JsonDict]
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
for user_id, keys in local_result.items():
if user_id in local_query:
results[user_id] = keys
if "master_keys" in remote_result:
for user_id, key in remote_result["master_keys"].items():
# Get cached cross-signing keys
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, from_user_id
)
# Now attempt to get any remote devices from our local cache.
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache = (
{}
) # type: Dict[str, Dict[str, Iterable[str]]]
if remote_queries:
query_list = [] # type: List[Tuple[str, Optional[str]]]
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend(
(user_id, device_id) for device_id in device_ids
)
else:
query_list.append((user_id, None))
(
user_ids_not_in_cache,
remote_results,
) = await self.store.get_user_devices_from_cache(query_list)
for user_id, devices in remote_results.items():
user_devices = results.setdefault(user_id, {})
for device_id, device in devices.items():
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
result = dict(keys)
unsigned = result.setdefault("unsigned", {})
if device_display_name:
unsigned["device_display_name"] = device_display_name
user_devices[device_id] = result
# check for missing cross-signing keys.
for user_id in remote_queries.keys():
cached_cross_master = user_id in cross_signing_keys["master_keys"]
cached_cross_selfsigning = (
user_id in cross_signing_keys["self_signing_keys"]
)
# check if we are missing only one of cross-signing master or
# self-signing key, but the other one is cached.
# as we need both, this will issue a federation request.
# if we don't have any of the keys, either the user doesn't have
# cross-signing set up, or the cached device list
# is not (yet) updated.
if cached_cross_master ^ cached_cross_selfsigning:
user_ids_not_in_cache.add(user_id)
# add those users to the list to fetch over federation.
for user_id in user_ids_not_in_cache:
domain = get_domain_from_id(user_id)
r = remote_queries_not_in_cache.setdefault(domain, {})
r[user_id] = remote_queries[user_id]
# Now fetch any devices that we don't have in our cache
@trace
async def do_remote_query(destination):
"""This is called when we are querying the device list of a user on
a remote homeserver and their device list is not in the device list
cache. If we share a room with this user and we're not querying for
specific user we will update the cache with their device list.
"""
destination_query = remote_queries_not_in_cache[destination]
# We first consider whether we wish to update the device list cache with
# the users device list. We want to track a user's devices when the
# authenticated user shares a room with the queried user and the query
# has not specified a particular device.
# If we update the cache for the queried user we remove them from further
# queries. We use the more efficient batched query_client_keys for all
# remaining users
user_ids_updated = []
for (user_id, device_list) in destination_query.items():
if user_id in user_ids_updated:
continue
if device_list:
continue
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
continue
# We've decided we're sharing a room with this user and should
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
if self._is_master:
user_devices = await self.device_handler.device_list_updater.user_device_resync(
user_id
)
else:
user_devices = await self._user_device_resync_client(
user_id=user_id
)
user_devices = user_devices["devices"]
user_results = results.setdefault(user_id, {})
for device in user_devices:
user_results[device["device_id"]] = device["keys"]
user_ids_updated.append(user_id)
except Exception as e:
failures[destination] = _exception_to_failure(e)
if len(destination_query) == len(user_ids_updated):
# We've updated all the users in the query and we do not need to
# make any further remote calls.
return
# Remove all the users from the query which we have updated
for user_id in user_ids_updated:
destination_query.pop(user_id)
try:
remote_result = await self.federation.query_client_keys(
destination, {"device_keys": destination_query}, timeout=timeout
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query:
cross_signing_keys["master_keys"][user_id] = key
results[user_id] = keys
if "self_signing_keys" in remote_result:
for user_id, key in remote_result["self_signing_keys"].items():
if user_id in destination_query:
cross_signing_keys["self_signing_keys"][user_id] = key
if "master_keys" in remote_result:
for user_id, key in remote_result["master_keys"].items():
if user_id in destination_query:
cross_signing_keys["master_keys"][user_id] = key
except Exception as e:
failure = _exception_to_failure(e)
failures[destination] = failure
set_tag("error", True)
set_tag("reason", failure)
if "self_signing_keys" in remote_result:
for user_id, key in remote_result["self_signing_keys"].items():
if user_id in destination_query:
cross_signing_keys["self_signing_keys"][user_id] = key
await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(do_remote_query, destination)
for destination in remote_queries_not_in_cache
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
except Exception as e:
failure = _exception_to_failure(e)
failures[destination] = failure
set_tag("error", True)
set_tag("reason", failure)
ret = {"device_keys": results, "failures": failures}
await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(do_remote_query, destination)
for destination in remote_queries_not_in_cache
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
ret.update(cross_signing_keys)
ret = {"device_keys": results, "failures": failures}
return ret
ret.update(cross_signing_keys)
return ret
async def get_cross_signing_keys_from_cache(
self, query: Iterable[str], from_user_id: Optional[str]

View file

@ -1,6 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -34,6 +33,7 @@ from typing import (
)
import attr
from prometheus_client import Counter
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
@ -102,6 +102,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
soft_failed_event_counter = Counter(
"synapse_federation_soft_failed_events_total",
"Events received over federation that we marked as soft_failed",
)
@attr.s(slots=True)
class _NewEventInfo:
@ -1550,6 +1555,77 @@ class FederationHandler(BaseHandler):
run_in_background(self._handle_queued_pdus, room_queue)
@log_function
async def do_knock(
self,
target_hosts: List[str],
room_id: str,
knockee: str,
content: JsonDict,
) -> Tuple[str, int]:
"""Sends the knock to the remote server.
This first triggers a make_knock request that returns a partial
event that we can fill out and sign. This is then sent to the
remote server via send_knock.
Knock events must be signed by the knockee's server before distributing.
Args:
target_hosts: A list of hosts that we want to try knocking through.
room_id: The ID of the room to knock on.
knockee: The ID of the user who is knocking.
content: The content of the knock event.
Returns:
A tuple of (event ID, stream ID).
Raises:
SynapseError: If the chosen remote server returns a 3xx/4xx code.
RuntimeError: If no servers were reachable.
"""
logger.debug("Knocking on room %s on behalf of user %s", room_id, knockee)
# Inform the remote server of the room versions we support
supported_room_versions = list(KNOWN_ROOM_VERSIONS.keys())
# Ask the remote server to create a valid knock event for us. Once received,
# we sign the event
params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]]
origin, event, event_format_version = await self._make_and_verify_event(
target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
)
# Record the room ID and its version so that we have a record of the room
await self._maybe_store_room_on_outlier_membership(
room_id=event.room_id, room_version=event_format_version
)
# Initially try the host that we successfully called /make_knock on
try:
target_hosts.remove(origin)
target_hosts.insert(0, origin)
except ValueError:
pass
# Send the signed event back to the room, and potentially receive some
# further information about the room in the form of partial state events
stripped_room_state = await self.federation_client.send_knock(
target_hosts, event
)
# Store any stripped room state events in the "unsigned" key of the event.
# This is a bit of a hack and is cribbing off of invites. Basically we
# store the room state here and retrieve it again when this event appears
# in the invitee's sync stream. It is stripped out for all other local users.
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
context = await self.state_handler.compute_event_context(event)
stream_id = await self.persist_events_and_notify(
event.room_id, [(event, context)]
)
return event.event_id, stream_id
async def _handle_queued_pdus(
self, room_queue: List[Tuple[EventBase, str]]
) -> None:
@ -1915,6 +1991,114 @@ class FederationHandler(BaseHandler):
return None
@log_function
async def on_make_knock_request(
self, origin: str, room_id: str, user_id: str
) -> EventBase:
"""We've received a make_knock request, so we create a partial
knock event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
Args:
origin: The (verified) server name of the requesting server.
room_id: The room to create the knock event in.
user_id: The user to create the knock for.
Returns:
The partial knock event.
"""
if get_domain_from_id(user_id) != origin:
logger.info(
"Get /make_knock request for user %r from different origin %s, ignoring",
user_id,
origin,
)
raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
room_version = await self.store.get_room_version_id(room_id)
builder = self.event_builder_factory.new(
room_version,
{
"type": EventTypes.Member,
"content": {"membership": Membership.KNOCK},
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
},
)
event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
)
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
logger.warning("Creation of knock %s forbidden by third-party rules", event)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_knock_request`
await self.auth.check_from_context(
room_version, event, context, do_sig_check=False
)
except AuthError as e:
logger.warning("Failed to create new knock %r because %s", event, e)
raise e
return event
@log_function
async def on_send_knock_request(
self, origin: str, event: EventBase
) -> EventContext:
"""
We have received a knock event for a room. Verify that event and send it into the room
on the knocking homeserver's behalf.
Args:
origin: The remote homeserver of the knocking user.
event: The knocking member event that has been signed by the remote homeserver.
Returns:
The context of the event after inserting it into the room graph.
"""
logger.debug(
"on_send_knock_request: Got event: %s, signatures: %s",
event.event_id,
event.signatures,
)
if get_domain_from_id(event.sender) != origin:
logger.info(
"Got /send_knock request for user %r from different origin %s",
event.sender,
origin,
)
raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
event.internal_metadata.outlier = False
context = await self.state_handler.compute_event_context(event)
await self._auth_and_persist_event(origin, event, context)
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
logger.info("Sending of knock %s forbidden by third-party rules", event)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
return context
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event."""
@ -2318,6 +2502,7 @@ class FederationHandler(BaseHandler):
event_auth.check(room_version_obj, event, auth_events=current_auth_events)
except AuthError as e:
logger.warning("Soft-failing %r because %s", event, e)
soft_failed_event_counter.inc()
event.internal_metadata.soft_failed = True
async def on_get_missing_events(

View file

@ -1,6 +1,7 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
# Copyrignt 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -398,13 +399,14 @@ class EventCreationHandler:
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
self.room_invite_state_types = self.hs.config.api.room_prejoin_state
self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
self.membership_types_to_include_profile_data_in = (
{Membership.JOIN, Membership.INVITE}
if self.hs.config.include_profile_data_on_invite
else {Membership.JOIN}
)
self.membership_types_to_include_profile_data_in = {
Membership.JOIN,
Membership.KNOCK,
}
if self.hs.config.include_profile_data_on_invite:
self.membership_types_to_include_profile_data_in.add(Membership.INVITE)
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
@ -961,8 +963,8 @@ class EventCreationHandler:
room_version = await self.store.get_room_version_id(event.room_id)
if event.internal_metadata.is_out_of_band_membership():
# the only sort of out-of-band-membership events we expect to see here
# are invite rejections we have generated ourselves.
# the only sort of out-of-band-membership events we expect to see here are
# invite rejections and rescinded knocks that we have generated ourselves.
assert event.type == EventTypes.Member
assert event.content["membership"] == Membership.LEAVE
else:
@ -1239,7 +1241,7 @@ class EventCreationHandler:
"invite_room_state"
] = await self.store.get_stripped_room_state_from_event_context(
context,
self.room_invite_state_types,
self.room_prejoin_state_types,
membership_user_id=event.sender,
)
@ -1257,6 +1259,14 @@ class EventCreationHandler:
# TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures)
if event.content["membership"] == Membership.KNOCK:
event.unsigned[
"knock_room_state"
] = await self.store.get_stripped_room_state_from_event_context(
context,
self.room_prejoin_state_types,
)
if event.type == EventTypes.Redaction:
original_event = await self.store.get_event(
event.redacts,

View file

@ -45,7 +45,7 @@ class RoomListHandler(BaseHandler):
self.response_cache = ResponseCache(
hs.get_clock(), "room_list"
) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
) # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]]
self.remote_response_cache = ResponseCache(
hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
@ -55,7 +55,7 @@ class RoomListHandler(BaseHandler):
limit: Optional[int] = None,
since_token: Optional[str] = None,
search_filter: Optional[dict] = None,
network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID,
from_federation: bool = False,
) -> JsonDict:
"""Generate a local public room list.
@ -112,7 +112,7 @@ class RoomListHandler(BaseHandler):
limit: Optional[int] = None,
since_token: Optional[str] = None,
search_filter: Optional[dict] = None,
network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID,
from_federation: bool = False,
) -> JsonDict:
"""Generate a public room list.
@ -170,6 +170,7 @@ class RoomListHandler(BaseHandler):
"world_readable": room["history_visibility"]
== HistoryVisibility.WORLD_READABLE,
"guest_can_join": room["guest_access"] == "can_join",
"join_rule": room["join_rules"],
}
# Filter out Nones rather omit the field altogether

View file

@ -1,4 +1,5 @@
# Copyright 2016-2020 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
import random
@ -30,7 +30,15 @@ from synapse.api.errors import (
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.types import (
JsonDict,
Requester,
RoomAlias,
RoomID,
StateMap,
UserID,
get_domain_from_id,
)
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_left_room
@ -126,6 +134,24 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
"""
raise NotImplementedError()
@abc.abstractmethod
async def remote_knock(
self,
remote_room_hosts: List[str],
room_id: str,
user: UserID,
content: dict,
) -> Tuple[str, int]:
"""Try and knock on a room that this server is not in
Args:
remote_room_hosts: List of servers that can be used to knock via.
room_id: Room that we are trying to knock on.
user: User who is trying to knock.
content: A dict that should be used as the content of the knock event.
"""
raise NotImplementedError()
@abc.abstractmethod
async def remote_reject_invite(
self,
@ -149,6 +175,27 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
"""
raise NotImplementedError()
@abc.abstractmethod
async def remote_rescind_knock(
self,
knock_event_id: str,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
) -> Tuple[str, int]:
"""Rescind a local knock made on a remote room.
Args:
knock_event_id: The ID of the knock event to rescind.
txn_id: An optional transaction ID supplied by the client.
requester: The user making the request, according to the access token.
content: The content of the generated leave event.
Returns:
A tuple containing (event_id, stream_id of the leave event).
"""
raise NotImplementedError()
@abc.abstractmethod
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has left the
@ -623,53 +670,79 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
# Figure out the user's current membership state for the room
(
current_membership_type,
current_membership_event_id,
) = await self.store.get_local_current_membership_for_user_in_room(
target.to_string(), room_id
)
if (
current_membership_type != Membership.INVITE
or not current_membership_event_id
):
if not current_membership_type or not current_membership_event_id:
logger.info(
"%s sent a leave request to %s, but that is not an active room "
"on this server, and there is no pending invite",
"on this server, or there is no pending invite or knock",
target,
room_id,
)
raise SynapseError(404, "Not a known room")
invite = await self.store.get_event(current_membership_event_id)
logger.info(
"%s rejects invite to %s from %s", target, room_id, invite.sender
)
if not self.hs.is_mine_id(invite.sender):
# send the rejection to the inviter's HS (with fallback to
# local event)
return await self.remote_reject_invite(
invite.event_id,
txn_id,
requester,
content,
# perhaps we've been invited
if current_membership_type == Membership.INVITE:
invite = await self.store.get_event(current_membership_event_id)
logger.info(
"%s rejects invite to %s from %s",
target,
room_id,
invite.sender,
)
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath, which will also send the
# rejection out to any other servers we believe are still in the room.
if not self.hs.is_mine_id(invite.sender):
# send the rejection to the inviter's HS (with fallback to
# local event)
return await self.remote_reject_invite(
invite.event_id,
txn_id,
requester,
content,
)
# thanks to overzealous cleaning up of event_forward_extremities in
# `delete_old_current_state_events`, it's possible to end up with no
# forward extremities here. If that happens, let's just hang the
# rejection off the invite event.
#
# see: https://github.com/matrix-org/synapse/issues/7139
if len(latest_event_ids) == 0:
latest_event_ids = [invite.event_id]
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath, which will also send the
# rejection out to any other servers we believe are still in the room.
# thanks to overzealous cleaning up of event_forward_extremities in
# `delete_old_current_state_events`, it's possible to end up with no
# forward extremities here. If that happens, let's just hang the
# rejection off the invite event.
#
# see: https://github.com/matrix-org/synapse/issues/7139
if len(latest_event_ids) == 0:
latest_event_ids = [invite.event_id]
# or perhaps this is a remote room that a local user has knocked on
elif current_membership_type == Membership.KNOCK:
knock = await self.store.get_event(current_membership_event_id)
return await self.remote_rescind_knock(
knock.event_id, txn_id, requester, content
)
elif effective_membership_state == Membership.KNOCK:
if not is_host_in_room:
# The knock needs to be sent over federation instead
remote_room_hosts.append(get_domain_from_id(room_id))
content["membership"] = Membership.KNOCK
profile = self.profile_handler
if "displayname" not in content:
content["displayname"] = await profile.get_displayname(target)
if "avatar_url" not in content:
content["avatar_url"] = await profile.get_avatar_url(target)
return await self.remote_knock(
remote_room_hosts, room_id, target, content
)
return await self._local_membership_update(
requester=requester,
@ -1229,6 +1302,35 @@ class RoomMemberMasterHandler(RoomMemberHandler):
invite_event, txn_id, requester, content
)
async def remote_rescind_knock(
self,
knock_event_id: str,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
) -> Tuple[str, int]:
"""
Rescinds a local knock made on a remote room
Args:
knock_event_id: The ID of the knock event to rescind.
txn_id: The transaction ID to use.
requester: The originator of the request.
content: The content of the leave event.
Implements RoomMemberHandler.remote_rescind_knock
"""
# TODO: We don't yet support rescinding knocks over federation
# as we don't know which homeserver to send it to. An obvious
# candidate is the remote homeserver we originally knocked through,
# however we don't currently store that information.
# Just rescind the knock locally
knock_event = await self.store.get_event(knock_event_id)
return await self._generate_local_out_of_band_leave(
knock_event, txn_id, requester, content
)
async def _generate_local_out_of_band_leave(
self,
previous_membership_event: EventBase,
@ -1292,6 +1394,36 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return result_event.event_id, result_event.internal_metadata.stream_ordering
async def remote_knock(
self,
remote_room_hosts: List[str],
room_id: str,
user: UserID,
content: dict,
) -> Tuple[str, int]:
"""Sends a knock to a room. Attempts to do so via one remote out of a given list.
Args:
remote_room_hosts: A list of homeservers to try knocking through.
room_id: The ID of the room to knock on.
user: The user to knock on behalf of.
content: The content of the knock event.
Returns:
A tuple of (event ID, stream ID).
"""
# filter ourselves out of remote_room_hosts
remote_room_hosts = [
host for host in remote_room_hosts if host != self.hs.hostname
]
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
return await self.federation_handler.do_knock(
remote_room_hosts, room_id, user.to_string(), content=content
)
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room"""
user_left_room(self.distributor, target, room_id)

View file

@ -1,4 +1,4 @@
# Copyright 2018 New Vector Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -19,10 +19,12 @@ from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler
from synapse.replication.http.membership import (
ReplicationRemoteJoinRestServlet as ReplRemoteJoin,
ReplicationRemoteKnockRestServlet as ReplRemoteKnock,
ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
ReplicationRemoteRescindKnockRestServlet as ReplRescindKnock,
ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
)
from synapse.types import Requester, UserID
from synapse.types import JsonDict, Requester, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -35,7 +37,9 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
super().__init__(hs)
self._remote_join_client = ReplRemoteJoin.make_client(hs)
self._remote_knock_client = ReplRemoteKnock.make_client(hs)
self._remote_reject_client = ReplRejectInvite.make_client(hs)
self._remote_rescind_client = ReplRescindKnock.make_client(hs)
self._notify_change_client = ReplJoinedLeft.make_client(hs)
async def _remote_join(
@ -80,6 +84,53 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
)
return ret["event_id"], ret["stream_id"]
async def remote_rescind_knock(
self,
knock_event_id: str,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
) -> Tuple[str, int]:
"""
Rescinds a local knock made on a remote room
Args:
knock_event_id: the knock event
txn_id: optional transaction ID supplied by the client
requester: user making the request, according to the access token
content: additional content to include in the leave event.
Normally an empty dict.
Returns:
A tuple containing (event_id, stream_id of the leave event)
"""
ret = await self._remote_rescind_client(
knock_event_id=knock_event_id,
txn_id=txn_id,
requester=requester,
content=content,
)
return ret["event_id"], ret["stream_id"]
async def remote_knock(
self,
remote_room_hosts: List[str],
room_id: str,
user: UserID,
content: dict,
) -> Tuple[str, int]:
"""Sends a knock to a room.
Implements RoomMemberHandler.remote_knock
"""
ret = await self._remote_knock_client(
remote_room_hosts=remote_room_hosts,
room_id=room_id,
user=user,
content=content,
)
return ret["event_id"], ret["stream_id"]
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room"""
await self._notify_change_client(

View file

@ -402,10 +402,7 @@ class SpaceSummaryHandler:
return (), ()
return res.rooms, tuple(
ev.data
for ev in res.events
if ev.event_type == EventTypes.MSC1772_SPACE_CHILD
or ev.event_type == EventTypes.SpaceChild
ev.data for ev in res.events if ev.event_type == EventTypes.SpaceChild
)
async def _is_room_accessible(
@ -514,11 +511,6 @@ class SpaceSummaryHandler:
current_state_ids[(EventTypes.Create, "")]
)
# TODO: update once MSC1772 lands
room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
if not room_type:
room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE)
room_version = await self._store.get_room_version(room_id)
allowed_spaces = None
if await self._event_auth_handler.has_restricted_join_rules(
@ -540,7 +532,7 @@ class SpaceSummaryHandler:
),
"guest_can_join": stats["guest_access"] == "can_join",
"creation_ts": create_event.origin_server_ts,
"room_type": room_type,
"room_type": create_event.content.get(EventContentFields.ROOM_TYPE),
"allowed_spaces": allowed_spaces,
}
@ -569,9 +561,7 @@ class SpaceSummaryHandler:
[
event_id
for key, event_id in current_state_ids.items()
# TODO: update once MSC1772 has been FCP for a period of time.
if key[0] == EventTypes.MSC1772_SPACE_CHILD
or key[0] == EventTypes.SpaceChild
if key[0] == EventTypes.SpaceChild
]
)

View file

@ -1,4 +1,5 @@
# Copyright 2018 New Vector Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -230,6 +231,8 @@ class StatsHandler:
room_stats_delta["left_members"] -= 1
elif prev_membership == Membership.BAN:
room_stats_delta["banned_members"] -= 1
elif prev_membership == Membership.KNOCK:
room_stats_delta["knocked_members"] -= 1
else:
raise ValueError(
"%r is not a valid prev_membership" % (prev_membership,)
@ -251,6 +254,8 @@ class StatsHandler:
room_stats_delta["left_members"] += 1
elif membership == Membership.BAN:
room_stats_delta["banned_members"] += 1
elif membership == Membership.KNOCK:
room_stats_delta["knocked_members"] += 1
else:
raise ValueError("%r is not a valid membership" % (membership,))

View file

@ -160,6 +160,16 @@ class InvitedSyncResult:
return True
@attr.s(slots=True, frozen=True)
class KnockedSyncResult:
room_id = attr.ib(type=str)
knock = attr.ib(type=EventBase)
def __bool__(self) -> bool:
"""Knocked rooms should always be reported to the client"""
return True
@attr.s(slots=True, frozen=True)
class GroupsSyncResult:
join = attr.ib(type=JsonDict)
@ -193,6 +203,7 @@ class _RoomChanges:
room_entries = attr.ib(type=List["RoomSyncResultBuilder"])
invited = attr.ib(type=List[InvitedSyncResult])
knocked = attr.ib(type=List[KnockedSyncResult])
newly_joined_rooms = attr.ib(type=List[str])
newly_left_rooms = attr.ib(type=List[str])
@ -206,6 +217,7 @@ class SyncResult:
account_data: List of account_data events for the user.
joined: JoinedSyncResult for each joined room.
invited: InvitedSyncResult for each invited room.
knocked: KnockedSyncResult for each knocked on room.
archived: ArchivedSyncResult for each archived room.
to_device: List of direct messages for the device.
device_lists: List of user_ids whose devices have changed
@ -221,6 +233,7 @@ class SyncResult:
account_data = attr.ib(type=List[JsonDict])
joined = attr.ib(type=List[JoinedSyncResult])
invited = attr.ib(type=List[InvitedSyncResult])
knocked = attr.ib(type=List[KnockedSyncResult])
archived = attr.ib(type=List[ArchivedSyncResult])
to_device = attr.ib(type=List[JsonDict])
device_lists = attr.ib(type=DeviceLists)
@ -237,6 +250,7 @@ class SyncResult:
self.presence
or self.joined
or self.invited
or self.knocked
or self.archived
or self.account_data
or self.to_device
@ -1032,7 +1046,7 @@ class SyncHandler:
res = await self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room
)
newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
newly_joined_rooms, newly_joined_or_invited_or_knocked_users, _, _ = res
_, _, newly_left_rooms, newly_left_users = res
block_all_presence_data = (
@ -1041,7 +1055,9 @@ class SyncHandler:
if self.hs_config.use_presence and not block_all_presence_data:
logger.debug("Fetching presence data")
await self._generate_sync_entry_for_presence(
sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
sync_result_builder,
newly_joined_rooms,
newly_joined_or_invited_or_knocked_users,
)
logger.debug("Fetching to-device data")
@ -1050,7 +1066,7 @@ class SyncHandler:
device_lists = await self._generate_sync_entry_for_device_list(
sync_result_builder,
newly_joined_rooms=newly_joined_rooms,
newly_joined_or_invited_users=newly_joined_or_invited_users,
newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
newly_left_rooms=newly_left_rooms,
newly_left_users=newly_left_users,
)
@ -1084,6 +1100,7 @@ class SyncHandler:
account_data=sync_result_builder.account_data,
joined=sync_result_builder.joined,
invited=sync_result_builder.invited,
knocked=sync_result_builder.knocked,
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
@ -1143,7 +1160,7 @@ class SyncHandler:
self,
sync_result_builder: "SyncResultBuilder",
newly_joined_rooms: Set[str],
newly_joined_or_invited_users: Set[str],
newly_joined_or_invited_or_knocked_users: Set[str],
newly_left_rooms: Set[str],
newly_left_users: Set[str],
) -> DeviceLists:
@ -1152,8 +1169,9 @@ class SyncHandler:
Args:
sync_result_builder
newly_joined_rooms: Set of rooms user has joined since previous sync
newly_joined_or_invited_users: Set of users that have joined or
been invited to a room since previous sync.
newly_joined_or_invited_or_knocked_users: Set of users that have joined,
been invited to a room or are knocking on a room since
previous sync.
newly_left_rooms: Set of rooms user has left since previous sync
newly_left_users: Set of users that have left a room we're in since
previous sync
@ -1164,7 +1182,9 @@ class SyncHandler:
# We're going to mutate these fields, so lets copy them rather than
# assume they won't get used later.
newly_joined_or_invited_users = set(newly_joined_or_invited_users)
newly_joined_or_invited_or_knocked_users = set(
newly_joined_or_invited_or_knocked_users
)
newly_left_users = set(newly_left_users)
if since_token and since_token.device_list_key:
@ -1203,11 +1223,11 @@ class SyncHandler:
# Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms:
joined_users = await self.store.get_users_in_room(room_id)
newly_joined_or_invited_users.update(joined_users)
newly_joined_or_invited_or_knocked_users.update(joined_users)
# TODO: Check that these users are actually new, i.e. either they
# weren't in the previous sync *or* they left and rejoined.
users_that_have_changed.update(newly_joined_or_invited_users)
users_that_have_changed.update(newly_joined_or_invited_or_knocked_users)
user_signatures_changed = (
await self.store.get_users_whose_signatures_changed(
@ -1453,6 +1473,7 @@ class SyncHandler:
room_entries = room_changes.room_entries
invited = room_changes.invited
knocked = room_changes.knocked
newly_joined_rooms = room_changes.newly_joined_rooms
newly_left_rooms = room_changes.newly_left_rooms
@ -1473,9 +1494,10 @@ class SyncHandler:
await concurrently_execute(handle_room_entries, room_entries, 10)
sync_result_builder.invited.extend(invited)
sync_result_builder.knocked.extend(knocked)
# Now we want to get any newly joined or invited users
newly_joined_or_invited_users = set()
# Now we want to get any newly joined, invited or knocking users
newly_joined_or_invited_or_knocked_users = set()
newly_left_users = set()
if since_token:
for joined_sync in sync_result_builder.joined:
@ -1487,19 +1509,22 @@ class SyncHandler:
if (
event.membership == Membership.JOIN
or event.membership == Membership.INVITE
or event.membership == Membership.KNOCK
):
newly_joined_or_invited_users.add(event.state_key)
newly_joined_or_invited_or_knocked_users.add(
event.state_key
)
else:
prev_content = event.unsigned.get("prev_content", {})
prev_membership = prev_content.get("membership", None)
if prev_membership == Membership.JOIN:
newly_left_users.add(event.state_key)
newly_left_users -= newly_joined_or_invited_users
newly_left_users -= newly_joined_or_invited_or_knocked_users
return (
set(newly_joined_rooms),
newly_joined_or_invited_users,
newly_joined_or_invited_or_knocked_users,
set(newly_left_rooms),
newly_left_users,
)
@ -1554,6 +1579,7 @@ class SyncHandler:
newly_left_rooms = []
room_entries = []
invited = []
knocked = []
for room_id, events in mem_change_events_by_room_id.items():
logger.debug(
"Membership changes in %s: [%s]",
@ -1633,9 +1659,17 @@ class SyncHandler:
should_invite = non_joins[-1].membership == Membership.INVITE
if should_invite:
if event.sender not in ignored_users:
room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
if room_sync:
invited.append(room_sync)
invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
if invite_room_sync:
invited.append(invite_room_sync)
# Only bother if our latest membership in the room is knock (and we haven't
# been accepted/rejected in the meantime).
should_knock = non_joins[-1].membership == Membership.KNOCK
if should_knock:
knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1])
if knock_room_sync:
knocked.append(knock_room_sync)
# Always include leave/ban events. Just take the last one.
# TODO: How do we handle ban -> leave in same batch?
@ -1739,7 +1773,13 @@ class SyncHandler:
)
room_entries.append(entry)
return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms)
return _RoomChanges(
room_entries,
invited,
knocked,
newly_joined_rooms,
newly_left_rooms,
)
async def _get_all_rooms(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
@ -1759,6 +1799,7 @@ class SyncHandler:
membership_list = (
Membership.INVITE,
Membership.KNOCK,
Membership.JOIN,
Membership.LEAVE,
Membership.BAN,
@ -1770,6 +1811,7 @@ class SyncHandler:
room_entries = []
invited = []
knocked = []
for event in room_list:
if event.membership == Membership.JOIN:
@ -1789,8 +1831,11 @@ class SyncHandler:
continue
invite = await self.store.get_event(event.event_id)
invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
elif event.membership == Membership.KNOCK:
knock = await self.store.get_event(event.event_id)
knocked.append(KnockedSyncResult(room_id=event.room_id, knock=knock))
elif event.membership in (Membership.LEAVE, Membership.BAN):
# Always send down rooms we were banned or kicked from.
# Always send down rooms we were banned from or kicked from.
if not sync_config.filter_collection.include_leave:
if event.membership == Membership.LEAVE:
if user_id == event.sender:
@ -1811,7 +1856,7 @@ class SyncHandler:
)
)
return _RoomChanges(room_entries, invited, [], [])
return _RoomChanges(room_entries, invited, knocked, [], [])
async def _generate_room_entry(
self,
@ -2102,6 +2147,7 @@ class SyncResultBuilder:
account_data (list)
joined (list[JoinedSyncResult])
invited (list[InvitedSyncResult])
knocked (list[KnockedSyncResult])
archived (list[ArchivedSyncResult])
groups (GroupsSyncResult|None)
to_device (list)
@ -2117,6 +2163,7 @@ class SyncResultBuilder:
account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list))
joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list))
invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list))
knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list))
archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list))
groups = attr.ib(type=Optional[GroupsSyncResult], default=None)
to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))

View file

@ -65,13 +65,9 @@ from synapse.http.client import (
read_body_with_max_size,
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import (
inject_active_span_byte_dict,
set_tag,
start_active_span,
tags,
)
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import ISynapseReactor, JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
@ -322,7 +318,9 @@ class MatrixFederationHttpClient:
# We need to use a DNS resolver which filters out blacklisted IP
# addresses, to prevent DNS rebinding.
self.reactor = BlacklistingReactorWrapper(
hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
hs.get_reactor(),
hs.config.federation_ip_range_whitelist,
hs.config.federation_ip_range_blacklist,
) # type: ISynapseReactor
user_agent = hs.version_string
@ -497,7 +495,7 @@ class MatrixFederationHttpClient:
# Inject the span into the headers
headers_dict = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers_dict, request.destination)
opentracing.inject_header_dict(headers_dict, request.destination)
headers_dict[b"User-Agent"] = [self.version_string_bytes]

View file

@ -13,7 +13,6 @@
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
import logging
from typing import Dict, Iterable, List, Optional, overload
@ -295,6 +294,30 @@ def parse_strings_from_args(
return default
@overload
def parse_string_from_args(
args: Dict[bytes, List[bytes]],
name: str,
default: Optional[str] = None,
required: Literal[True] = True,
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
) -> str:
...
@overload
def parse_string_from_args(
args: Dict[bytes, List[bytes]],
name: str,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
def parse_string_from_args(
args: Dict[bytes, List[bytes]],
name: str,

View file

@ -168,7 +168,7 @@ import inspect
import logging
import re
from functools import wraps
from typing import TYPE_CHECKING, Dict, Optional, Pattern, Type
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Type
import attr
@ -278,6 +278,10 @@ class SynapseTags:
DB_TXN_ID = "db.txn_id"
class SynapseBaggage:
FORCE_TRACING = "synapse-force-tracing"
# Block everything by default
# A regex which matches the server_names to expose traces for.
# None means 'block everything'.
@ -285,6 +289,8 @@ _homeserver_whitelist = None # type: Optional[Pattern[str]]
# Util methods
Sentinel = object()
def only_if_tracing(func):
"""Executes the function only if we're tracing. Otherwise returns None."""
@ -447,12 +453,28 @@ def start_active_span(
)
def start_active_span_follows_from(operation_name, contexts):
def start_active_span_follows_from(
operation_name: str, contexts: Collection, inherit_force_tracing=False
):
"""Starts an active opentracing span, with additional references to previous spans
Args:
operation_name: name of the operation represented by the new span
contexts: the previous spans to inherit from
inherit_force_tracing: if set, and any of the previous contexts have had tracing
forced, the new span will also have tracing forced.
"""
if opentracing is None:
return noop_context_manager()
references = [opentracing.follows_from(context) for context in contexts]
scope = start_active_span(operation_name, references=references)
if inherit_force_tracing and any(
is_context_forced_tracing(ctx) for ctx in contexts
):
force_tracing(scope.span)
return scope
@ -551,6 +573,10 @@ def start_active_span_from_edu(
# Opentracing setters for tags, logs, etc
@only_if_tracing
def active_span():
"""Get the currently active span, if any"""
return opentracing.tracer.active_span
@ensure_active_span("set a tag")
@ -571,25 +597,52 @@ def set_operation_name(operation_name):
opentracing.tracer.active_span.set_operation_name(operation_name)
@only_if_tracing
def force_tracing(span=Sentinel) -> None:
"""Force sampling for the active/given span and its children.
Args:
span: span to force tracing for. By default, the active span.
"""
if span is Sentinel:
span = opentracing.tracer.active_span
if span is None:
logger.error("No active span in force_tracing")
return
span.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
# also set a bit of baggage, so that we have a way of figuring out if
# it is enabled later
span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1")
def is_context_forced_tracing(span_context) -> bool:
"""Check if sampling has been force for the given span context."""
if span_context is None:
return False
return span_context.baggage.get(SynapseBaggage.FORCE_TRACING) is not None
# Injection and extraction
@ensure_active_span("inject the span into a header")
def inject_active_span_twisted_headers(headers, destination, check_destination=True):
@ensure_active_span("inject the span into a header dict")
def inject_header_dict(
headers: Dict[bytes, List[bytes]],
destination: Optional[str] = None,
check_destination: bool = True,
) -> None:
"""
Injects a span context into twisted headers in-place
Injects a span context into a dict of HTTP headers
Args:
headers (twisted.web.http_headers.Headers)
destination (str): address of entity receiving the span context. If check_destination
is true the context will only be injected if the destination matches the
opentracing whitelist
headers: the dict to inject headers into
destination: address of entity receiving the span context. Must be given unless
check_destination is False. The context will only be injected if the
destination matches the opentracing whitelist
check_destination (bool): If false, destination will be ignored and the context
will always be injected.
span (opentracing.Span)
Returns:
In-place modification of headers
Note:
The headers set by the tracer are custom to the tracer implementation which
@ -598,45 +651,13 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
here:
https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
"""
if check_destination and not whitelisted_homeserver(destination):
return
span = opentracing.tracer.active_span
carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
headers.addRawHeaders(key, value)
@ensure_active_span("inject the span into a byte dict")
def inject_active_span_byte_dict(headers, destination, check_destination=True):
"""
Injects a span context into a dict where the headers are encoded as byte
strings
Args:
headers (dict)
destination (str): address of entity receiving the span context. If check_destination
is true the context will only be injected if the destination matches the
opentracing whitelist
check_destination (bool): If false, destination will be ignored and the context
will always be injected.
span (opentracing.Span)
Returns:
In-place modification of headers
Note:
The headers set by the tracer are custom to the tracer implementation which
should be unique enough that they don't interfere with any headers set by
synapse or twisted. If we're still using jaeger these headers would be those
here:
https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
"""
if check_destination and not whitelisted_homeserver(destination):
return
if check_destination:
if destination is None:
raise ValueError(
"destination must be given unless check_destination is False"
)
if not whitelisted_homeserver(destination):
return
span = opentracing.tracer.active_span
@ -647,38 +668,6 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
headers[key.encode()] = [value.encode()]
@ensure_active_span("inject the span into a text map")
def inject_active_span_text_map(carrier, destination, check_destination=True):
"""
Injects a span context into a dict
Args:
carrier (dict)
destination (str): address of entity receiving the span context. If check_destination
is true the context will only be injected if the destination matches the
opentracing whitelist
check_destination (bool): If false, destination will be ignored and the context
will always be injected.
Returns:
In-place modification of carrier
Note:
The headers set by the tracer are custom to the tracer implementation which
should be unique enough that they don't interfere with any headers set by
synapse or twisted. If we're still using jaeger these headers would be those
here:
https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
"""
if check_destination and not whitelisted_homeserver(destination):
return
opentracing.tracer.inject(
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
)
@ensure_active_span("get the active span context as a dict", ret={})
def get_active_span_text_map(destination=None):
"""

View file

@ -23,7 +23,8 @@ from prometheus_client import Counter, Gauge
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError
from synapse.logging.opentracing import inject_active_span_byte_dict, trace
from synapse.logging import opentracing
from synapse.logging.opentracing import trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@ -235,7 +236,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# Add an authorization header, if configured.
if replication_secret:
headers[b"Authorization"] = [b"Bearer " + replication_secret]
inject_active_span_byte_dict(headers, None, check_destination=False)
opentracing.inject_header_dict(headers, check_destination=False)
try:
result = await request_func(uri, data, headers=headers)
break
@ -284,7 +285,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
self.__class__.__name__,
)
def _check_auth_and_handle(self, request, **kwargs):
async def _check_auth_and_handle(self, request, **kwargs):
"""Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response.
@ -299,8 +300,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if self.CACHE:
txn_id = kwargs.pop("txn_id")
return self.response_cache.wrap(
return await self.response_cache.wrap(
txn_id, self._handle_request, request, **kwargs
)
return self._handle_request(request, **kwargs)
return await self._handle_request(request, **kwargs)

View file

@ -97,6 +97,76 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
"""Perform a remote knock for the given user on the given room
Request format:
POST /_synapse/replication/remote_knock/:room_id/:user_id
{
"requester": ...,
"remote_room_hosts": [...],
"content": { ... }
}
"""
NAME = "remote_knock"
PATH_ARGS = ("room_id", "user_id")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.federation_handler = hs.get_federation_handler()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload( # type: ignore
requester: Requester,
room_id: str,
user_id: str,
remote_room_hosts: List[str],
content: JsonDict,
):
"""
Args:
requester: The user making the request, according to the access token.
room_id: The ID of the room to knock on.
user_id: The ID of the knocking user.
remote_room_hosts: Servers to try and send the knock via.
content: The event content to use for the knock event.
"""
return {
"requester": requester.serialize(),
"remote_room_hosts": remote_room_hosts,
"content": content,
}
async def _handle_request( # type: ignore
self,
request: SynapseRequest,
room_id: str,
user_id: str,
):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
request.requester = requester
logger.debug("remote_knock: %s on room: %s", user_id, room_id)
event_id, stream_id = await self.federation_handler.do_knock(
remote_room_hosts, room_id, user_id, event_content
)
return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"""Rejects an out-of-band invite we have received from a remote server
@ -167,6 +237,75 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
"""Rescinds a local knock made on a remote room
Request format:
POST /_synapse/replication/remote_rescind_knock/:event_id
{
"txn_id": ...,
"requester": ...,
"content": { ... }
}
"""
NAME = "remote_rescind_knock"
PATH_ARGS = ("knock_event_id",)
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.member_handler = hs.get_room_member_handler()
@staticmethod
async def _serialize_payload( # type: ignore
knock_event_id: str,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
):
"""
Args:
knock_event_id: The ID of the knock to be rescinded.
txn_id: An optional transaction ID supplied by the client.
requester: The user making the rescind request, according to the access token.
content: The content to include in the rescind event.
"""
return {
"txn_id": txn_id,
"requester": requester.serialize(),
"content": content,
}
async def _handle_request( # type: ignore
self,
request: SynapseRequest,
knock_event_id: str,
):
content = parse_json_object_from_request(request)
txn_id = content["txn_id"]
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
request.requester = requester
# hopefully we're now on the master, so this won't recurse!
event_id, stream_id = await self.member_handler.remote_rescind_knock(
knock_event_id,
txn_id,
requester,
event_content,
)
return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
"""Notifies that a user has joined or left the room
@ -206,7 +345,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
return {}
def _handle_request( # type: ignore
async def _handle_request( # type: ignore
self, request: Request, room_id: str, user_id: str, change: str
) -> Tuple[int, JsonDict]:
logger.info("user membership change: %s in %s", user_id, room_id)

View file

@ -38,6 +38,7 @@ from synapse.rest.client.v2_alpha import (
filter,
groups,
keys,
knock,
notifications,
openid,
password_policy,
@ -120,6 +121,7 @@ class ClientRestResource(JsonResource):
account_validity.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
password_policy.register_servlets(hs, client_resource)
knock.register_servlets(hs, client_resource)
# moving to /_synapse/admin
admin.register_servlets_for_client_rest_resource(hs, client_resource)

View file

@ -14,10 +14,9 @@
# limitations under the License.
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging
import re
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from urllib import parse as urlparse
from synapse.api.constants import EventTypes, Membership
@ -38,6 +37,7 @@ from synapse.http.servlet import (
parse_integer,
parse_json_object_from_request,
parse_string,
parse_strings_from_args,
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag
@ -278,7 +278,12 @@ class JoinRoomAliasServlet(TransactionRestServlet):
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
async def on_POST(self, request, room_identifier, txn_id=None):
async def on_POST(
self,
request: SynapseRequest,
room_identifier: str,
txn_id: Optional[str] = None,
):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
@ -290,17 +295,18 @@ class JoinRoomAliasServlet(TransactionRestServlet):
if RoomID.is_valid(room_identifier):
room_id = room_identifier
try:
remote_room_hosts = [
x.decode("ascii") for x in request.args[b"server_name"]
] # type: Optional[List[str]]
except Exception:
remote_room_hosts = None
# twisted.web.server.Request.args is incorrectly defined as Optional[Any]
args: Dict[bytes, List[bytes]] = request.args # type: ignore
remote_room_hosts = parse_strings_from_args(
args, "server_name", required=False
)
elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
room_id = room_id.to_string()
room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias)
room_id = room_id_obj.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)

View file

@ -160,9 +160,12 @@ class KeyQueryServlet(RestServlet):
async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
device_id = requester.device_id
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = await self.e2e_keys_handler.query_devices(body, timeout, user_id)
result = await self.e2e_keys_handler.query_devices(
body, timeout, user_id, device_id
)
return 200, result

View file

@ -0,0 +1,107 @@
# Copyright 2020 Sorunome
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from twisted.web.server import Request
from synapse.api.constants import Membership
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
parse_strings_from_args,
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict, RoomAlias, RoomID
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from ._base import client_patterns
logger = logging.getLogger(__name__)
class KnockRoomAliasServlet(RestServlet):
"""
POST /knock/{roomIdOrAlias}
"""
PATTERNS = client_patterns("/knock/(?P<room_identifier>[^/]*)")
def __init__(self, hs: "HomeServer"):
super().__init__()
self.txns = HttpTransactionCache(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
async def on_POST(
self,
request: SynapseRequest,
room_identifier: str,
txn_id: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
event_content = None
if "reason" in content:
event_content = {"reason": content["reason"]}
if RoomID.is_valid(room_identifier):
room_id = room_identifier
# twisted.web.server.Request.args is incorrectly defined as Optional[Any]
args: Dict[bytes, List[bytes]] = request.args # type: ignore
remote_room_hosts = parse_strings_from_args(
args, "server_name", required=False
)
elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias)
room_id = room_id_obj.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
await self.room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
action=Membership.KNOCK,
txn_id=txn_id,
third_party_signed=None,
remote_room_hosts=remote_room_hosts,
content=event_content,
)
return 200, {"room_id": room_id}
def on_PUT(self, request: Request, room_identifier: str, txn_id: str):
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_identifier, txn_id
)
def register_servlets(hs, http_server):
KnockRoomAliasServlet(hs).register(http_server)

View file

@ -85,7 +85,7 @@ class IdTokenServlet(RestServlet):
"access_token": token,
"token_type": "Bearer",
"matrix_server_name": self.server_name,
"expires_in": self.EXPIRES_MS / 1000,
"expires_in": self.EXPIRES_MS // 1000,
},
)

View file

@ -11,12 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
from synapse.api.constants import PresenceState
from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import (
@ -24,7 +23,7 @@ from synapse.events.utils import (
format_event_raw,
)
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig
from synapse.handlers.sync import KnockedSyncResult, SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, StreamToken
@ -220,6 +219,10 @@ class SyncRestServlet(RestServlet):
sync_result.invited, time_now, access_token_id, event_formatter
)
knocked = await self.encode_knocked(
sync_result.knocked, time_now, access_token_id, event_formatter
)
archived = await self.encode_archived(
sync_result.archived,
time_now,
@ -237,11 +240,16 @@ class SyncRestServlet(RestServlet):
"left": list(sync_result.device_lists.left),
},
"presence": SyncRestServlet.encode_presence(sync_result.presence, time_now),
"rooms": {"join": joined, "invite": invited, "leave": archived},
"rooms": {
Membership.JOIN: joined,
Membership.INVITE: invited,
Membership.KNOCK: knocked,
Membership.LEAVE: archived,
},
"groups": {
"join": sync_result.groups.join,
"invite": sync_result.groups.invite,
"leave": sync_result.groups.leave,
Membership.JOIN: sync_result.groups.join,
Membership.INVITE: sync_result.groups.invite,
Membership.LEAVE: sync_result.groups.leave,
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
"org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
@ -303,7 +311,7 @@ class SyncRestServlet(RestServlet):
Args:
rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
sync results for rooms this user is joined to
sync results for rooms this user is invited to
time_now(int): current time - used as a baseline for age
calculations
token_id(int): ID of the user's auth token - used for namespacing
@ -322,7 +330,7 @@ class SyncRestServlet(RestServlet):
time_now,
token_id=token_id,
event_format=event_formatter,
is_invite=True,
include_stripped_room_state=True,
)
unsigned = dict(invite.get("unsigned", {}))
invite["unsigned"] = unsigned
@ -332,6 +340,60 @@ class SyncRestServlet(RestServlet):
return invited
async def encode_knocked(
self,
rooms: List[KnockedSyncResult],
time_now: int,
token_id: int,
event_formatter: Callable[[Dict], Dict],
) -> Dict[str, Dict[str, Any]]:
"""
Encode the rooms we've knocked on in a sync result.
Args:
rooms: list of sync results for rooms this user is knocking on
time_now: current time - used as a baseline for age calculations
token_id: ID of the user's auth token - used for namespacing of transaction IDs
event_formatter: function to convert from federation format to client format
Returns:
The list of rooms the user has knocked on, in our response format.
"""
knocked = {}
for room in rooms:
knock = await self._event_serializer.serialize_event(
room.knock,
time_now,
token_id=token_id,
event_format=event_formatter,
include_stripped_room_state=True,
)
# Extract the `unsigned` key from the knock event.
# This is where we (cheekily) store the knock state events
unsigned = knock.setdefault("unsigned", {})
# Duplicate the dictionary in order to avoid modifying the original
unsigned = dict(unsigned)
# Extract the stripped room state from the unsigned dict
# This is for clients to get a little bit of information about
# the room they've knocked on, without revealing any sensitive information
knocked_state = list(unsigned.pop("knock_room_state", []))
# Append the actual knock membership event itself as well. This provides
# the client with:
#
# * A knock state event that they can use for easier internal tracking
# * The rough timestamp of when the knock occurred contained within the event
knocked_state.append(knock)
# Build the `knock_state` dictionary, which will contain the state of the
# room that the client has knocked on
knocked[room.room_id] = {"knock_state": {"events": knocked_state}}
return knocked
async def encode_archived(
self, rooms, time_now, token_id, event_fields, event_formatter
):

View file

@ -19,7 +19,7 @@ from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.storage._base import SQLBaseStore, db_to_json
@ -177,11 +177,13 @@ class RoomWorkerStore(SQLBaseStore):
INNER JOIN room_stats_current USING (room_id)
WHERE
(
join_rules = 'public' OR history_visibility = 'world_readable'
join_rules = 'public' OR join_rules = '%(knock_join_rule)s'
OR history_visibility = 'world_readable'
)
AND joined_members > 0
""" % {
"published_sql": published_sql
"published_sql": published_sql,
"knock_join_rule": JoinRules.KNOCK,
}
txn.execute(sql, query_args)
@ -303,7 +305,7 @@ class RoomWorkerStore(SQLBaseStore):
sql = """
SELECT
room_id, name, topic, canonical_alias, joined_members,
avatar, history_visibility, joined_members, guest_access
avatar, history_visibility, guest_access, join_rules
FROM (
%(published_sql)s
) published
@ -311,7 +313,8 @@ class RoomWorkerStore(SQLBaseStore):
INNER JOIN room_stats_current USING (room_id)
WHERE
(
join_rules = 'public' OR history_visibility = 'world_readable'
join_rules = 'public' OR join_rules = '%(knock_join_rule)s'
OR history_visibility = 'world_readable'
)
AND joined_members > 0
%(where_clause)s
@ -320,6 +323,7 @@ class RoomWorkerStore(SQLBaseStore):
"published_sql": published_sql,
"where_clause": where_clause,
"dir": "DESC" if forwards else "ASC",
"knock_join_rule": JoinRules.KNOCK,
}
if limit is not None:

View file

@ -41,6 +41,7 @@ ABSOLUTE_STATS_FIELDS = {
"current_state_events",
"joined_members",
"invited_members",
"knocked_members",
"left_members",
"banned_members",
"local_users_in_room",

View file

@ -16,9 +16,24 @@
import itertools
import logging
from collections import deque, namedtuple
from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
from collections import deque
from typing import (
Any,
Awaitable,
Callable,
Collection,
Deque,
Dict,
Generic,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
)
import attr
from prometheus_client import Counter, Histogram
from twisted.internet import defer
@ -26,6 +41,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging import opentracing
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
@ -37,7 +53,7 @@ from synapse.types import (
StateMap,
get_domain_from_id,
)
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@ -89,25 +105,53 @@ times_pruned_extremities = Counter(
)
class _EventPeristenceQueue:
@attr.s(auto_attribs=True, slots=True)
class _EventPersistQueueItem:
events_and_contexts: List[Tuple[EventBase, EventContext]]
backfilled: bool
deferred: ObservableDeferred
parent_opentracing_span_contexts: List = []
"""A list of opentracing spans waiting for this batch"""
opentracing_span_context: Any = None
"""The opentracing span under which the persistence actually happened"""
_PersistResult = TypeVar("_PersistResult")
class _EventPeristenceQueue(Generic[_PersistResult]):
"""Queues up events so that they can be persisted in bulk with only one
concurrent transaction per room.
"""
_EventPersistQueueItem = namedtuple(
"_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
)
def __init__(
self,
per_item_callback: Callable[
[List[Tuple[EventBase, EventContext]], bool],
Awaitable[_PersistResult],
],
):
"""Create a new event persistence queue
def __init__(self):
self._event_persist_queues = {}
self._currently_persisting_rooms = set()
The per_item_callback will be called for each item added via add_to_queue,
and its result will be returned via the Deferreds returned from add_to_queue.
"""
self._event_persist_queues: Dict[str, Deque[_EventPersistQueueItem]] = {}
self._currently_persisting_rooms: Set[str] = set()
self._per_item_callback = per_item_callback
def add_to_queue(self, room_id, events_and_contexts, backfilled):
async def add_to_queue(
self,
room_id: str,
events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
backfilled: bool,
) -> _PersistResult:
"""Add events to the queue, with the given persist_event options.
NB: due to the normal usage pattern of this method, it does *not*
follow the synapse logcontext rules, and leaves the logcontext in
place whether or not the returned deferred is ready.
If we are not already processing events in this room, starts off a background
process to to so, calling the per_item_callback for each item.
Args:
room_id (str):
@ -115,38 +159,54 @@ class _EventPeristenceQueue:
backfilled (bool):
Returns:
defer.Deferred: a deferred which will resolve once the events are
persisted. Runs its callbacks *without* a logcontext. The result
is the same as that returned by the callback passed to
`handle_queue`.
the result returned by the `_per_item_callback` passed to
`__init__`.
"""
queue = self._event_persist_queues.setdefault(room_id, deque())
if queue:
# if the last item in the queue has the same `backfilled` setting,
# we can just add these new events to that item.
# if the last item in the queue has the same `backfilled` setting,
# we can just add these new events to that item.
if queue and queue[-1].backfilled == backfilled:
end_item = queue[-1]
if end_item.backfilled == backfilled:
end_item.events_and_contexts.extend(events_and_contexts)
return end_item.deferred.observe()
else:
# need to make a new queue item
deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
queue.append(
self._EventPersistQueueItem(
events_and_contexts=events_and_contexts,
end_item = _EventPersistQueueItem(
events_and_contexts=[],
backfilled=backfilled,
deferred=deferred,
)
)
queue.append(end_item)
return deferred.observe()
# add our events to the queue item
end_item.events_and_contexts.extend(events_and_contexts)
def handle_queue(self, room_id, per_item_callback):
# also add our active opentracing span to the item so that we get a link back
span = opentracing.active_span()
if span:
end_item.parent_opentracing_span_contexts.append(span.context)
# start a processor for the queue, if there isn't one already
self._handle_queue(room_id)
# wait for the queue item to complete
res = await make_deferred_yieldable(end_item.deferred.observe())
# add another opentracing span which links to the persist trace.
with opentracing.start_active_span_follows_from(
"persist_event_batch_complete", (end_item.opentracing_span_context,)
):
pass
return res
def _handle_queue(self, room_id):
"""Attempts to handle the queue for a room if not already being handled.
The given callback will be invoked with for each item in the queue,
The queue's callback will be invoked with for each item in the queue,
of type _EventPersistQueueItem. The per_item_callback will continuously
be called with new items, unless the queue becomnes empty. The return
be called with new items, unless the queue becomes empty. The return
value of the function will be given to the deferreds waiting on the item,
exceptions will be passed to the deferreds as well.
@ -156,7 +216,6 @@ class _EventPeristenceQueue:
If another callback is currently handling the queue then it will not be
invoked.
"""
if room_id in self._currently_persisting_rooms:
return
@ -167,7 +226,17 @@ class _EventPeristenceQueue:
queue = self._get_drainining_queue(room_id)
for item in queue:
try:
ret = await per_item_callback(item)
with opentracing.start_active_span_follows_from(
"persist_event_batch",
item.parent_opentracing_span_contexts,
inherit_force_tracing=True,
) as scope:
if scope:
item.opentracing_span_context = scope.span.context
ret = await self._per_item_callback(
item.events_and_contexts, item.backfilled
)
except Exception:
with PreserveLoggingContext():
item.deferred.errback()
@ -214,9 +283,10 @@ class EventsPersistenceStorage:
self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
self.is_mine_id = hs.is_mine_id
self._event_persist_queue = _EventPeristenceQueue()
self._event_persist_queue = _EventPeristenceQueue(self._persist_event_batch)
self._state_resolution_handler = hs.get_state_resolution_handler()
@opentracing.trace
async def persist_events(
self,
events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
@ -241,26 +311,21 @@ class EventsPersistenceStorage:
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
deferreds = []
for room_id, evs_ctxs in partitioned.items():
d = self._event_persist_queue.add_to_queue(
async def enqueue(item):
room_id, evs_ctxs = item
return await self._event_persist_queue.add_to_queue(
room_id, evs_ctxs, backfilled=backfilled
)
deferreds.append(d)
for room_id in partitioned:
self._maybe_start_persisting(room_id)
ret_vals = await yieldable_gather_results(enqueue, partitioned.items())
# Each deferred returns a map from event ID to existing event ID if the
# event was deduplicated. (The dict may also include other entries if
# Each call to add_to_queue returns a map from event ID to existing event ID if
# the event was deduplicated. (The dict may also include other entries if
# the event was persisted in a batch with other events).
#
# Since we use `defer.gatherResults` we need to merge the returned list
# Since we use `yieldable_gather_results` we need to merge the returned list
# of dicts into one.
ret_vals = await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
replaced_events = {}
replaced_events: Dict[str, str] = {}
for d in ret_vals:
replaced_events.update(d)
@ -277,6 +342,7 @@ class EventsPersistenceStorage:
self.main_store.get_room_max_token(),
)
@opentracing.trace
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
@ -287,16 +353,12 @@ class EventsPersistenceStorage:
event if it was deduplicated due to an existing event matching the
transaction ID.
"""
deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], backfilled=backfilled
)
self._maybe_start_persisting(event.room_id)
# The deferred returns a map from event ID to existing event ID if the
# add_to_queue returns a map from event ID to existing event ID if the
# event was deduplicated. (The dict may also include other entries if
# the event was persisted in a batch with other events.)
replaced_events = await make_deferred_yieldable(deferred)
replaced_events = await self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], backfilled=backfilled
)
replaced_event = replaced_events.get(event.event_id)
if replaced_event:
event = await self.main_store.get_event(replaced_event)
@ -308,29 +370,14 @@ class EventsPersistenceStorage:
pos = PersistedEventPosition(self._instance_name, event_stream_id)
return event, pos, self.main_store.get_room_max_token()
def _maybe_start_persisting(self, room_id: str):
"""Pokes the `_event_persist_queue` to start handling new items in the
queue, if not already in progress.
Causes the deferreds returned by `add_to_queue` to resolve with: a
dictionary of event ID to event ID we didn't persist as we already had
another event persisted with the same TXN ID.
"""
async def persisting_queue(item):
with Measure(self._clock, "persist_events"):
return await self._persist_events(
item.events_and_contexts, backfilled=item.backfilled
)
self._event_persist_queue.handle_queue(room_id, persisting_queue)
async def _persist_events(
async def _persist_event_batch(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
) -> Dict[str, str]:
"""Calculates the change to current state and forward extremities, and
"""Callback for the _event_persist_queue
Calculates the change to current state and forward extremities, and
persists the given events and with those updates.
Returns:

View file

@ -1,5 +1,4 @@
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2014 - 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -26,7 +25,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.engines.postgres import PostgresEngine
from synapse.storage.schema import SCHEMA_VERSION
from synapse.storage.schema import SCHEMA_COMPAT_VERSION, SCHEMA_VERSION
from synapse.storage.types import Cursor
logger = logging.getLogger(__name__)
@ -59,6 +58,28 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = (
)
@attr.s
class _SchemaState:
current_version: int = attr.ib()
"""The current schema version of the database"""
compat_version: Optional[int] = attr.ib()
"""The SCHEMA_VERSION of the oldest version of Synapse for this database
If this is None, we have an old version of the database without the necessary
table.
"""
applied_deltas: Collection[str] = attr.ib(factory=tuple)
"""Any delta files for `current_version` which have already been applied"""
upgraded: bool = attr.ib(default=False)
"""Whether the current state was reached by applying deltas.
If False, we have run the full schema for `current_version`, and have applied no
deltas since. If True, we have run some deltas since the original creation."""
def prepare_database(
db_conn: LoggingDatabaseConnection,
database_engine: BaseDatabaseEngine,
@ -96,12 +117,11 @@ def prepare_database(
version_info = _get_or_create_schema_state(cur, database_engine)
if version_info:
user_version, delta_files, upgraded = version_info
logger.info(
"%r: Existing schema is %i (+%i deltas)",
databases,
user_version,
len(delta_files),
version_info.current_version,
len(version_info.applied_deltas),
)
# config should only be None when we are preparing an in-memory SQLite db,
@ -113,16 +133,18 @@ def prepare_database(
# if it's a worker app, refuse to upgrade the database, to avoid multiple
# workers doing it at once.
if config.worker_app is not None and user_version != SCHEMA_VERSION:
if (
config.worker_app is not None
and version_info.current_version != SCHEMA_VERSION
):
raise UpgradeDatabaseException(
OUTDATED_SCHEMA_ON_WORKER_ERROR % (SCHEMA_VERSION, user_version)
OUTDATED_SCHEMA_ON_WORKER_ERROR
% (SCHEMA_VERSION, version_info.current_version)
)
_upgrade_existing_database(
cur,
user_version,
delta_files,
upgraded,
version_info,
database_engine,
config,
databases=databases,
@ -261,9 +283,7 @@ def _setup_new_database(
_upgrade_existing_database(
cur,
current_version=max_current_ver,
applied_delta_files=[],
upgraded=False,
_SchemaState(current_version=max_current_ver, compat_version=None),
database_engine=database_engine,
config=None,
databases=databases,
@ -273,9 +293,7 @@ def _setup_new_database(
def _upgrade_existing_database(
cur: Cursor,
current_version: int,
applied_delta_files: List[str],
upgraded: bool,
current_schema_state: _SchemaState,
database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig],
databases: Collection[str],
@ -321,12 +339,8 @@ def _upgrade_existing_database(
Args:
cur
current_version: The current version of the schema.
applied_delta_files: A list of deltas that have already been applied.
upgraded: Whether the current version was generated by having
applied deltas or from full schema file. If `True` the function
will never apply delta files for the given `current_version`, since
the current_version wasn't generated by applying those delta files.
current_schema_state: The current version of the schema, as
returned by _get_or_create_schema_state
database_engine
config:
None if we are initialising a blank database, otherwise the application
@ -337,13 +351,16 @@ def _upgrade_existing_database(
upgrade portions of the delta scripts.
"""
if is_empty:
assert not applied_delta_files
assert not current_schema_state.applied_deltas
else:
assert config
is_worker = config and config.worker_app is not None
if current_version > SCHEMA_VERSION:
if (
current_schema_state.compat_version is not None
and current_schema_state.compat_version > SCHEMA_VERSION
):
raise ValueError(
"Cannot use this database as it is too "
+ "new for the server to understand"
@ -357,14 +374,26 @@ def _upgrade_existing_database(
assert config is not None
check_database_before_upgrade(cur, database_engine, config)
start_ver = current_version
# update schema_compat_version before we run any upgrades, so that if synapse
# gets downgraded again, it won't try to run against the upgraded database.
if (
current_schema_state.compat_version is None
or current_schema_state.compat_version < SCHEMA_COMPAT_VERSION
):
cur.execute("DELETE FROM schema_compat_version")
cur.execute(
"INSERT INTO schema_compat_version(compat_version) VALUES (?)",
(SCHEMA_COMPAT_VERSION,),
)
start_ver = current_schema_state.current_version
# if we got to this schema version by running a full_schema rather than a series
# of deltas, we should not run the deltas for this version.
if not upgraded:
if not current_schema_state.upgraded:
start_ver += 1
logger.debug("applied_delta_files: %s", applied_delta_files)
logger.debug("applied_delta_files: %s", current_schema_state.applied_deltas)
if isinstance(database_engine, PostgresEngine):
specific_engine_extension = ".postgres"
@ -440,7 +469,7 @@ def _upgrade_existing_database(
absolute_path = entry.absolute_path
logger.debug("Found file: %s (%s)", relative_path, absolute_path)
if relative_path in applied_delta_files:
if relative_path in current_schema_state.applied_deltas:
continue
root_name, ext = os.path.splitext(file_name)
@ -621,7 +650,7 @@ def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None:
def _get_or_create_schema_state(
txn: Cursor, database_engine: BaseDatabaseEngine
) -> Optional[Tuple[int, List[str], bool]]:
) -> Optional[_SchemaState]:
# Bluntly try creating the schema_version tables.
sql_path = os.path.join(schema_path, "common", "schema_version.sql")
executescript(txn, sql_path)
@ -629,17 +658,31 @@ def _get_or_create_schema_state(
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
if row is not None:
current_version = int(row[0])
txn.execute(
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
(current_version,),
)
applied_deltas = [d for d, in txn]
upgraded = bool(row[1])
return current_version, applied_deltas, upgraded
if row is None:
# new database
return None
return None
current_version = int(row[0])
upgraded = bool(row[1])
compat_version: Optional[int] = None
txn.execute("SELECT compat_version FROM schema_compat_version")
row = txn.fetchone()
if row is not None:
compat_version = int(row[0])
txn.execute(
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
(current_version,),
)
applied_deltas = tuple(d for d, in txn)
return _SchemaState(
current_version=current_version,
compat_version=compat_version,
applied_deltas=applied_deltas,
upgraded=upgraded,
)
@attr.s(slots=True)

View file

@ -1,37 +1,4 @@
# Synapse Database Schemas
This directory contains the schema files used to build Synapse databases.
Synapse supports splitting its datastore across multiple physical databases (which can
be useful for large installations), and the schema files are therefore split according
to the logical database they are apply to.
At the time of writing, the following "logical" databases are supported:
* `state` - used to store Matrix room state (more specifically, `state_groups`,
their relationships and contents.)
* `main` - stores everything else.
Addionally, the `common` directory contains schema files for tables which must be
present on *all* physical databases.
## Full schema dumps
In the `full_schemas` directories, only the most recently-numbered snapshot is useful
(`54` at the time of writing). Older snapshots (eg, `16`) are present for historical
reference only.
## Building full schema dumps
If you want to recreate these schemas, they need to be made from a database that
has had all background updates run.
To do so, use `scripts-dev/make_full_schema.sh`. This will produce new
`full.sql.postgres` and `full.sql.sqlite` files.
Ensure postgres is installed, then run:
./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
NB at the time of writing, this script predates the split into separate `state`/`main`
databases so will require updates to handle that correctly.
This directory contains the schema files used to build Synapse databases. For more
information, see /docs/development/database_schema.md.

View file

@ -12,6 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 59
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
shape of the database schema (even if those requirements are backwards-compatible with
older versions of Synapse).
See `README.md <synapse/storage/schema/README.md>`_ for more information on how this
works.
"""
SCHEMA_COMPAT_VERSION = 59
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
This value is stored in the database, and checked on startup. If the value in the
database is greater than SCHEMA_VERSION, then Synapse will refuse to start.
"""

View file

@ -20,6 +20,13 @@ CREATE TABLE IF NOT EXISTS schema_version(
CHECK (Lock='X')
);
CREATE TABLE IF NOT EXISTS schema_compat_version(
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
-- The SCHEMA_VERSION of the oldest synapse this database can be used with
compat_version INTEGER NOT NULL,
CHECK (Lock='X')
);
CREATE TABLE IF NOT EXISTS applied_schema_deltas(
version INTEGER NOT NULL,
file TEXT NOT NULL,

View file

@ -0,0 +1,17 @@
/* Copyright 2020 Sorunome
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE room_stats_current ADD COLUMN knocked_members INT NOT NULL DEFAULT '0';
ALTER TABLE room_stats_historical ADD COLUMN knocked_members BIGINT NOT NULL DEFAULT '0';

View file

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar
import attr
from twisted.internet import defer
@ -23,10 +25,36 @@ from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
T = TypeVar("T")
# the type of the key in the cache
KV = TypeVar("KV")
# the type of the result from the operation
RV = TypeVar("RV")
class ResponseCache(Generic[T]):
@attr.s(auto_attribs=True)
class ResponseCacheContext(Generic[KV]):
"""Information about a missed ResponseCache hit
This object can be passed into the callback for additional feedback
"""
cache_key: KV
"""The cache key that caused the cache miss
This should be considered read-only.
TODO: in attrs 20.1, make it frozen with an on_setattr.
"""
should_cache: bool = True
"""Whether the result should be cached once the request completes.
This can be modified by the callback if it decides its result should not be cached.
"""
class ResponseCache(Generic[KV]):
"""
This caches a deferred response. Until the deferred completes it will be
returned from the cache. This means that if the client retries the request
@ -35,8 +63,10 @@ class ResponseCache(Generic[T]):
"""
def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
# Requests that haven't finished yet.
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
# This is poorly-named: it includes both complete and incomplete results.
# We keep complete results rather than switching to absolute values because
# that makes it easier to cache Failure results.
self.pending_result_cache = {} # type: Dict[KV, ObservableDeferred]
self.clock = clock
self.timeout_sec = timeout_ms / 1000.0
@ -50,16 +80,13 @@ class ResponseCache(Generic[T]):
def __len__(self) -> int:
return self.size()
def get(self, key: T) -> Optional[defer.Deferred]:
def get(self, key: KV) -> Optional[defer.Deferred]:
"""Look up the given key.
Can return either a new Deferred (which also doesn't follow the synapse
logcontext rules), or, if the request has completed, the actual
result. You will probably want to make_deferred_yieldable the result.
Returns a new Deferred (which also doesn't follow the synapse
logcontext rules). You will probably want to make_deferred_yieldable the result.
If there is no entry for the key, returns None. It is worth noting that
this means there is no way to distinguish a completed result of None
from an absent cache entry.
If there is no entry for the key, returns None.
Args:
key: key to get/set in the cache
@ -76,42 +103,56 @@ class ResponseCache(Generic[T]):
self._metrics.inc_misses()
return None
def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred:
def _set(
self, context: ResponseCacheContext[KV], deferred: defer.Deferred
) -> defer.Deferred:
"""Set the entry for the given key to the given deferred.
*deferred* should run its callbacks in the sentinel logcontext (ie,
you should wrap normal synapse deferreds with
synapse.logging.context.run_in_background).
Can return either a new Deferred (which also doesn't follow the synapse
logcontext rules), or, if *deferred* was already complete, the actual
result. You will probably want to make_deferred_yieldable the result.
Returns a new Deferred (which also doesn't follow the synapse logcontext rules).
You will probably want to make_deferred_yieldable the result.
Args:
key: key to get/set in the cache
context: Information about the cache miss
deferred: The deferred which resolves to the result.
Returns:
A new deferred which resolves to the actual result.
"""
result = ObservableDeferred(deferred, consumeErrors=True)
key = context.cache_key
self.pending_result_cache[key] = result
def remove(r):
if self.timeout_sec:
def on_complete(r):
# if this cache has a non-zero timeout, and the callback has not cleared
# the should_cache bit, we leave it in the cache for now and schedule
# its removal later.
if self.timeout_sec and context.should_cache:
self.clock.call_later(
self.timeout_sec, self.pending_result_cache.pop, key, None
)
else:
# otherwise, remove the result immediately.
self.pending_result_cache.pop(key, None)
return r
result.addBoth(remove)
# make sure we do this *after* adding the entry to pending_result_cache,
# in case the result is already complete (in which case flipping the order would
# leave us with a stuck entry in the cache).
result.addBoth(on_complete)
return result.observe()
def wrap(
self, key: T, callback: Callable[..., Any], *args: Any, **kwargs: Any
) -> defer.Deferred:
async def wrap(
self,
key: KV,
callback: Callable[..., Awaitable[RV]],
*args: Any,
cache_context: bool = False,
**kwargs: Any,
) -> RV:
"""Wrap together a *get* and *set* call, taking care of logcontexts
First looks up the key in the cache, and if it is present makes it
@ -140,22 +181,28 @@ class ResponseCache(Generic[T]):
*args: positional parameters to pass to the callback, if it is used
cache_context: if set, the callback will be given a `cache_context` kw arg,
which will be a ResponseCacheContext object.
**kwargs: named parameters to pass to the callback, if it is used
Returns:
Deferred which resolves to the result
The result of the callback (from the cache, or otherwise)
"""
result = self.get(key)
if not result:
logger.debug(
"[%s]: no cached result for [%s], calculating new one", self._name, key
)
context = ResponseCacheContext(cache_key=key)
if cache_context:
kwargs["cache_context"] = context
d = run_in_background(callback, *args, **kwargs)
result = self.set(key, d)
result = self._set(context, d)
elif not isinstance(result, defer.Deferred) or result.called:
logger.info("[%s]: using completed cached result for [%s]", self._name, key)
else:
logger.info(
"[%s]: using incomplete cached result for [%s]", self._name, key
)
return make_deferred_yieldable(result)
return await make_deferred_yieldable(result)

View file

@ -133,12 +133,17 @@ class Measure:
self.start = self.clock.time()
self._logging_context.__enter__()
in_flight.register((self.name,), self._update_in_flight)
logger.debug("Entering block %s", self.name)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.start is None:
raise RuntimeError("Measure() block exited without being entered")
logger.debug("Exiting block %s", self.name)
duration = self.clock.time() - self.start
usage = self.get_resource_usage()

View file

@ -0,0 +1,298 @@
# Copyright 2020 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import Dict, List
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import builder
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.server import HomeServer
from synapse.types import RoomAlias
from tests.test_utils import event_injection
from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
class KnockingStrippedStateEventHelperMixin(TestCase):
def send_example_state_events_to_room(
self,
hs: "HomeServer",
room_id: str,
sender: str,
) -> OrderedDict:
"""Adds some state to a room. State events are those that should be sent to a knocking
user after they knock on the room, as well as some state that *shouldn't* be sent
to the knocking user.
Args:
hs: The homeserver of the sender.
room_id: The ID of the room to send state into.
sender: The ID of the user to send state as. Must be in the room.
Returns:
The OrderedDict of event types and content that a user is expected to see
after knocking on a room.
"""
# To set a canonical alias, we'll need to point an alias at the room first.
canonical_alias = "#fancy_alias:test"
self.get_success(
self.store.create_room_alias_association(
RoomAlias.from_string(canonical_alias), room_id, ["test"]
)
)
# Send some state that we *don't* expect to be given to knocking users
self.get_success(
event_injection.inject_event(
hs,
room_version=RoomVersions.V7.identifier,
room_id=room_id,
sender=sender,
type="com.example.secret",
state_key="",
content={"secret": "password"},
)
)
# We use an OrderedDict here to ensure that the knock membership appears last.
# Note that order only matters when sending stripped state to clients, not federated
# homeservers.
room_state = OrderedDict(
[
# We need to set the room's join rules to allow knocking
(
EventTypes.JoinRules,
{"content": {"join_rule": JoinRules.KNOCK}, "state_key": ""},
),
# Below are state events that are to be stripped and sent to clients
(
EventTypes.Name,
{"content": {"name": "A cool room"}, "state_key": ""},
),
(
EventTypes.RoomAvatar,
{
"content": {
"info": {
"h": 398,
"mimetype": "image/jpeg",
"size": 31037,
"w": 394,
},
"url": "mxc://example.org/JWEIFJgwEIhweiWJE",
},
"state_key": "",
},
),
(
EventTypes.RoomEncryption,
{"content": {"algorithm": "m.megolm.v1.aes-sha2"}, "state_key": ""},
),
(
EventTypes.CanonicalAlias,
{
"content": {"alias": canonical_alias, "alt_aliases": []},
"state_key": "",
},
),
]
)
for event_type, event_dict in room_state.items():
event_content = event_dict["content"]
state_key = event_dict["state_key"]
self.get_success(
event_injection.inject_event(
hs,
room_version=RoomVersions.V7.identifier,
room_id=room_id,
sender=sender,
type=event_type,
state_key=state_key,
content=event_content,
)
)
# Finally, we expect to see the m.room.create event of the room as part of the
# stripped state. We don't need to inject this event though.
room_state[EventTypes.Create] = {
"content": {
"creator": sender,
"room_version": RoomVersions.V7.identifier,
},
"state_key": "",
}
return room_state
def check_knock_room_state_against_room_state(
self,
knock_room_state: List[Dict],
expected_room_state: Dict,
) -> None:
"""Test a list of stripped room state events received over federation against a
dict of expected state events.
Args:
knock_room_state: The list of room state that was received over federation.
expected_room_state: A dict containing the room state we expect to see in
`knock_room_state`.
"""
for event in knock_room_state:
event_type = event["type"]
# Check that this event type is one of those that we expected.
# Note: This will also check that no excess state was included
self.assertIn(event_type, expected_room_state)
# Check the state content matches
self.assertEquals(
expected_room_state[event_type]["content"], event["content"]
)
# Check the state key is correct
self.assertEqual(
expected_room_state[event_type]["state_key"], event["state_key"]
)
# Ensure the event has been stripped
self.assertNotIn("signatures", event)
# Pop once we've found and processed a state event
expected_room_state.pop(event_type)
# Check that all expected state events were accounted for
self.assertEqual(len(expected_room_state), 0)
class FederationKnockingTestCase(
FederatingHomeserverTestCase, KnockingStrippedStateEventHelperMixin
):
servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore()
# We're not going to be properly signing events as our remote homeserver is fake,
# therefore disable event signature checks.
# Note that these checks are not relevant to this test case.
# Have this homeserver auto-approve all event signature checking.
async def approve_all_signature_checking(_, pdu):
return pdu
homeserver.get_federation_server()._check_sigs_and_hash = (
approve_all_signature_checking
)
# Have this homeserver skip event auth checks. This is necessary due to
# event auth checks ensuring that events were signed by the sender's homeserver.
async def _check_event_auth(
origin, event, context, state, auth_events, backfilled
):
return context
homeserver.get_federation_handler()._check_event_auth = _check_event_auth
return super().prepare(reactor, clock, homeserver)
@override_config({"experimental_features": {"msc2403_enabled": True}})
def test_room_state_returned_when_knocking(self):
"""
Tests that specific, stripped state events from a room are returned after
a remote homeserver successfully knocks on a local room.
"""
user_id = self.register_user("u1", "you the one")
user_token = self.login("u1", "you the one")
fake_knocking_user_id = "@user:other.example.com"
# Create a room with a room version that includes knocking
room_id = self.helper.create_room_as(
"u1",
is_public=False,
room_version=RoomVersions.V7.identifier,
tok=user_token,
)
# Update the join rules and add additional state to the room to check for later
expected_room_state = self.send_example_state_events_to_room(
self.hs, room_id, user_id
)
channel = self.make_request(
"GET",
"/_matrix/federation/v1/make_knock/%s/%s?ver=%s"
% (
room_id,
fake_knocking_user_id,
# Inform the remote that we support the room version of the room we're
# knocking on
RoomVersions.V7.identifier,
),
)
self.assertEquals(200, channel.code, channel.result)
# Note: We don't expect the knock membership event to be sent over federation as
# part of the stripped room state, as the knocking homeserver already has that
# event. It is only done for clients during /sync
# Extract the generated knock event json
knock_event = channel.json_body["event"]
# Check that the event has things we expect in it
self.assertEquals(knock_event["room_id"], room_id)
self.assertEquals(knock_event["sender"], fake_knocking_user_id)
self.assertEquals(knock_event["state_key"], fake_knocking_user_id)
self.assertEquals(knock_event["type"], EventTypes.Member)
self.assertEquals(knock_event["content"]["membership"], Membership.KNOCK)
# Turn the event json dict into a proper event.
# We won't sign it properly, but that's OK as we stub out event auth in `prepare`
signed_knock_event = builder.create_local_event_from_event_dict(
self.clock,
self.hs.hostname,
self.hs.signing_key,
room_version=RoomVersions.V7,
event_dict=knock_event,
)
# Convert our proper event back to json dict format
signed_knock_event_json = signed_knock_event.get_pdu_json(
self.clock.time_msec()
)
# Send the signed knock event into the room
channel = self.make_request(
"PUT",
"/_matrix/federation/v1/send_knock/%s/%s"
% (room_id, signed_knock_event.event_id),
signed_knock_event_json,
)
self.assertEquals(200, channel.code, channel.result)
# Check that we got the stripped room state in return
room_state_events = channel.json_body["knock_state_events"]
# Validate the stripped room state events
self.check_knock_room_state_against_room_state(
room_state_events, expected_room_state
)

View file

@ -257,7 +257,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
devices = self.get_success(
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
self.handler.query_devices(
{"device_keys": {local_user: []}}, 0, local_user, "device123"
)
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
@ -357,7 +359,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
devices = self.get_success(
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
self.handler.query_devices(
{"device_keys": {local_user: []}}, 0, local_user, "device123"
)
)
del devices["device_keys"][local_user]["abc"]["unsigned"]
del devices["device_keys"][local_user]["def"]["unsigned"]
@ -591,7 +595,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# fetch the signed keys/devices and make sure that the signatures are there
ret = self.get_success(
self.handler.query_devices(
{"device_keys": {local_user: [], other_user: []}}, 0, local_user
{"device_keys": {local_user: [], other_user: []}},
0,
local_user,
"device123",
)
)

View file

@ -17,10 +17,14 @@ import json
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import read_marker, sync
from synapse.rest.client.v2_alpha import knock, read_marker, sync
from tests import unittest
from tests.federation.transport.test_knocking import (
KnockingStrippedStateEventHelperMixin,
)
from tests.server import TimedOutException
from tests.unittest import override_config
class FilterTestCase(unittest.HomeserverTestCase):
@ -305,6 +309,93 @@ class SyncTypingTests(unittest.HomeserverTestCase):
self.make_request("GET", sync_url % (access_token, next_batch))
class SyncKnockTestCase(
unittest.HomeserverTestCase, KnockingStrippedStateEventHelperMixin
):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
sync.register_servlets,
knock.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.url = "/sync?since=%s"
self.next_batch = "s0"
# Register the first user (used to create the room to knock on).
self.user_id = self.register_user("kermit", "monkey")
self.tok = self.login("kermit", "monkey")
# Create the room we'll knock on.
self.room_id = self.helper.create_room_as(
self.user_id,
is_public=False,
room_version="7",
tok=self.tok,
)
# Register the second user (used to knock on the room).
self.knocker = self.register_user("knocker", "monkey")
self.knocker_tok = self.login("knocker", "monkey")
# Perform an initial sync for the knocking user.
channel = self.make_request(
"GET",
self.url % self.next_batch,
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Store the next batch for the next request.
self.next_batch = channel.json_body["next_batch"]
# Set up some room state to test with.
self.expected_room_state = self.send_example_state_events_to_room(
hs, self.room_id, self.user_id
)
@override_config({"experimental_features": {"msc2403_enabled": True}})
def test_knock_room_state(self):
"""Tests that /sync returns state from a room after knocking on it."""
# Knock on a room
channel = self.make_request(
"POST",
"/_matrix/client/r0/knock/%s" % (self.room_id,),
b"{}",
self.knocker_tok,
)
self.assertEquals(200, channel.code, channel.result)
# We expect to see the knock event in the stripped room state later
self.expected_room_state[EventTypes.Member] = {
"content": {"membership": "knock", "displayname": "knocker"},
"state_key": "@knocker:test",
}
# Check that /sync includes stripped state from the room
channel = self.make_request(
"GET",
self.url % self.next_batch,
access_token=self.knocker_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Extract the stripped room state events from /sync
knock_entry = channel.json_body["rooms"]["knock"]
room_state_events = knock_entry[self.room_id]["knock_state"]["events"]
# Validate that the knock membership event came last
self.assertEqual(room_state_events[-1]["type"], EventTypes.Member)
# Validate the stripped room state events
self.check_knock_room_state_against_room_state(
room_state_events, self.expected_room_state
)
class UnreadMessagesTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
@ -447,7 +538,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
)
self._check_unread_count(5)
def _check_unread_count(self, expected_count: True):
def _check_unread_count(self, expected_count: int):
"""Syncs and compares the unread count with the expected value."""
channel = self.make_request(

View file

@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parameterized import parameterized
from synapse.util.caches.response_cache import ResponseCache
from twisted.internet import defer
from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
from tests.server import get_clock
from tests.unittest import TestCase
class DeferredCacheTestCase(TestCase):
class ResponseCacheTestCase(TestCase):
"""
A TestCase class for ResponseCache.
@ -48,7 +51,9 @@ class DeferredCacheTestCase(TestCase):
expected_result = "howdy"
wrap_d = cache.wrap(0, self.instant_return, expected_result)
wrap_d = defer.ensureDeferred(
cache.wrap(0, self.instant_return, expected_result)
)
self.assertEqual(
expected_result,
@ -66,7 +71,9 @@ class DeferredCacheTestCase(TestCase):
expected_result = "howdy"
wrap_d = cache.wrap(0, self.instant_return, expected_result)
wrap_d = defer.ensureDeferred(
cache.wrap(0, self.instant_return, expected_result)
)
self.assertEqual(
expected_result,
@ -80,7 +87,9 @@ class DeferredCacheTestCase(TestCase):
expected_result = "howdy"
wrap_d = cache.wrap(0, self.instant_return, expected_result)
wrap_d = defer.ensureDeferred(
cache.wrap(0, self.instant_return, expected_result)
)
self.assertEqual(expected_result, self.successResultOf(wrap_d))
self.assertEqual(
@ -99,7 +108,10 @@ class DeferredCacheTestCase(TestCase):
expected_result = "howdy"
wrap_d = cache.wrap(0, self.delayed_return, expected_result)
wrap_d = defer.ensureDeferred(
cache.wrap(0, self.delayed_return, expected_result)
)
self.assertNoResult(wrap_d)
# function wakes up, returns result
@ -112,7 +124,9 @@ class DeferredCacheTestCase(TestCase):
expected_result = "howdy"
wrap_d = cache.wrap(0, self.delayed_return, expected_result)
wrap_d = defer.ensureDeferred(
cache.wrap(0, self.delayed_return, expected_result)
)
self.assertNoResult(wrap_d)
# stop at 1 second to callback cache eviction callLater at that time, then another to set time at 2
@ -129,3 +143,50 @@ class DeferredCacheTestCase(TestCase):
self.reactor.pump((2,))
self.assertIsNone(cache.get(0), "cache should not have the result now")
@parameterized.expand([(True,), (False,)])
def test_cache_context_nocache(self, should_cache: bool):
"""If the callback clears the should_cache bit, the result should not be cached"""
cache = self.with_cache("medium_cache", ms=3000)
expected_result = "howdy"
call_count = 0
async def non_caching(o: str, cache_context: ResponseCacheContext[int]):
nonlocal call_count
call_count += 1
await self.clock.sleep(1)
cache_context.should_cache = should_cache
return o
wrap_d = defer.ensureDeferred(
cache.wrap(0, non_caching, expected_result, cache_context=True)
)
# there should be no result to start with
self.assertNoResult(wrap_d)
# a second call should also return a pending deferred
wrap2_d = defer.ensureDeferred(
cache.wrap(0, non_caching, expected_result, cache_context=True)
)
self.assertNoResult(wrap2_d)
# and there should have been exactly one call
self.assertEqual(call_count, 1)
# let the call complete
self.reactor.advance(1)
# both results should have completed
self.assertEqual(expected_result, self.successResultOf(wrap_d))
self.assertEqual(expected_result, self.successResultOf(wrap2_d))
if should_cache:
self.assertEqual(
expected_result,
self.successResultOf(cache.get(0)),
"cache should still have the result",
)
else:
self.assertIsNone(cache.get(0), "cache should not have the result")