0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 07:13:46 +01:00

Merge branch 'develop' of github.com:matrix-org/synapse into erikj/add_rate_limiting_to_joins

This commit is contained in:
Erik Johnston 2020-07-31 15:07:01 +01:00
commit faba873d4b
121 changed files with 1937 additions and 1256 deletions

View file

@ -1,10 +1,12 @@
- [Choosing your server name](#choosing-your-server-name) - [Choosing your server name](#choosing-your-server-name)
- [Picking a database engine](#picking-a-database-engine)
- [Installing Synapse](#installing-synapse) - [Installing Synapse](#installing-synapse)
- [Installing from source](#installing-from-source) - [Installing from source](#installing-from-source)
- [Platform-Specific Instructions](#platform-specific-instructions) - [Platform-Specific Instructions](#platform-specific-instructions)
- [Prebuilt packages](#prebuilt-packages) - [Prebuilt packages](#prebuilt-packages)
- [Setting up Synapse](#setting-up-synapse) - [Setting up Synapse](#setting-up-synapse)
- [TLS certificates](#tls-certificates) - [TLS certificates](#tls-certificates)
- [Client Well-Known URI](#client-well-known-uri)
- [Email](#email) - [Email](#email)
- [Registering a user](#registering-a-user) - [Registering a user](#registering-a-user)
- [Setting up a TURN server](#setting-up-a-turn-server) - [Setting up a TURN server](#setting-up-a-turn-server)
@ -27,6 +29,25 @@ that your email address is probably `user@example.com` rather than
`user@email.example.com`) - but doing so may require more advanced setup: see `user@email.example.com`) - but doing so may require more advanced setup: see
[Setting up Federation](docs/federate.md). [Setting up Federation](docs/federate.md).
# Picking a database engine
Synapse offers two database engines:
* [PostgreSQL](https://www.postgresql.org)
* [SQLite](https://sqlite.org/)
Almost all installations should opt to use PostgreSQL. Advantages include:
* significant performance improvements due to the superior threading and
caching model, smarter query optimiser
* allowing the DB to be run on separate hardware
For information on how to install and use PostgreSQL, please see
[docs/postgres.md](docs/postgres.md)
By default Synapse uses SQLite and in doing so trades performance for convenience.
SQLite is only recommended in Synapse for testing purposes or for servers with
light workloads.
# Installing Synapse # Installing Synapse
## Installing from source ## Installing from source
@ -234,9 +255,9 @@ for a number of platforms.
There is an offical synapse image available at There is an offical synapse image available at
https://hub.docker.com/r/matrixdotorg/synapse which can be used with https://hub.docker.com/r/matrixdotorg/synapse which can be used with
the docker-compose file available at [contrib/docker](contrib/docker). Further information on the docker-compose file available at [contrib/docker](contrib/docker). Further
this including configuration options is available in the README on information on this including configuration options is available in the README
hub.docker.com. on hub.docker.com.
Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a
Dockerfile to automate a synapse server in a single Docker image, at Dockerfile to automate a synapse server in a single Docker image, at
@ -244,7 +265,8 @@ https://hub.docker.com/r/avhost/docker-matrix/tags/
Slavi Pantaleev has created an Ansible playbook, Slavi Pantaleev has created an Ansible playbook,
which installs the offical Docker image of Matrix Synapse which installs the offical Docker image of Matrix Synapse
along with many other Matrix-related services (Postgres database, riot-web, coturn, mxisd, SSL support, etc.). along with many other Matrix-related services (Postgres database, Element, coturn,
ma1sd, SSL support, etc.).
For more details, see For more details, see
https://github.com/spantaleev/matrix-docker-ansible-deploy https://github.com/spantaleev/matrix-docker-ansible-deploy
@ -277,22 +299,27 @@ The fingerprint of the repository signing key (as shown by `gpg
/usr/share/keyrings/matrix-org-archive-keyring.gpg`) is /usr/share/keyrings/matrix-org-archive-keyring.gpg`) is
`AAF9AE843A7584B5A3E4CD2BCF45A512DE2DA058`. `AAF9AE843A7584B5A3E4CD2BCF45A512DE2DA058`.
#### Downstream Debian/Ubuntu packages #### Downstream Debian packages
For `buster` and `sid`, Synapse is available in the Debian repositories and We do not recommend using the packages from the default Debian `buster`
it should be possible to install it with simply: repository at this time, as they are old and suffer from known security
vulnerabilities. You can install the latest version of Synapse from
[our repository](#matrixorg-packages) or from `buster-backports`. Please
see the [Debian documentation](https://backports.debian.org/Instructions/)
for information on how to use backports.
If you are using Debian `sid` or testing, Synapse is available in the default
repositories and it should be possible to install it simply with:
``` ```
sudo apt install matrix-synapse sudo apt install matrix-synapse
``` ```
There is also a version of `matrix-synapse` in `stretch-backports`. Please see #### Downstream Ubuntu packages
the [Debian documentation on
backports](https://backports.debian.org/Instructions/) for information on how
to use them.
We do not recommend using the packages in downstream Ubuntu at this time, as We do not recommend using the packages in the default Ubuntu repository
they are old and suffer from known security vulnerabilities. at this time, as they are old and suffer from known security vulnerabilities.
The latest version of Synapse can be installed from [our repository](#matrixorg-packages).
### Fedora ### Fedora
@ -419,6 +446,60 @@ so, you will need to edit `homeserver.yaml`, as follows:
For a more detailed guide to configuring your server for federation, see For a more detailed guide to configuring your server for federation, see
[federate.md](docs/federate.md). [federate.md](docs/federate.md).
## Client Well-Known URI
Setting up the client Well-Known URI is optional but if you set it up, it will
allow users to enter their full username (e.g. `@user:<server_name>`) into clients
which support well-known lookup to automatically configure the homeserver and
identity server URLs. This is useful so that users don't have to memorize or think
about the actual homeserver URL you are using.
The URL `https://<server_name>/.well-known/matrix/client` should return JSON in
the following format.
```
{
"m.homeserver": {
"base_url": "https://<matrix.example.com>"
}
}
```
It can optionally contain identity server information as well.
```
{
"m.homeserver": {
"base_url": "https://<matrix.example.com>"
},
"m.identity_server": {
"base_url": "https://<identity.example.com>"
}
}
```
To work in browser based clients, the file must be served with the appropriate
Cross-Origin Resource Sharing (CORS) headers. A recommended value would be
`Access-Control-Allow-Origin: *` which would allow all browser based clients to
view it.
In nginx this would be something like:
```
location /.well-known/matrix/client {
return 200 '{"m.homeserver": {"base_url": "https://<matrix.example.com>"}}';
add_header Content-Type application/json;
add_header Access-Control-Allow-Origin *;
}
```
You should also ensure the `public_baseurl` option in `homeserver.yaml` is set
correctly. `public_baseurl` should be set to the URL that clients will use to
connect to your server. This is the same URL you put for the `m.homeserver`
`base_url` above.
```
public_baseurl: "https://<matrix.example.com>"
```
## Email ## Email
@ -437,7 +518,7 @@ email will be disabled.
## Registering a user ## Registering a user
The easiest way to create a new user is to do so from a client like [Riot](https://riot.im). The easiest way to create a new user is to do so from a client like [Element](https://element.io/).
Alternatively you can do so from the command line if you have installed via pip. Alternatively you can do so from the command line if you have installed via pip.

View file

@ -45,7 +45,7 @@ which handle:
- Eventually-consistent cryptographically secure synchronisation of room - Eventually-consistent cryptographically secure synchronisation of room
state across a global open network of federated servers and services state across a global open network of federated servers and services
- Sending and receiving extensible messages in a room with (optional) - Sending and receiving extensible messages in a room with (optional)
end-to-end encryption[1] end-to-end encryption
- Inviting, joining, leaving, kicking, banning room members - Inviting, joining, leaving, kicking, banning room members
- Managing user accounts (registration, login, logout) - Managing user accounts (registration, login, logout)
- Using 3rd Party IDs (3PIDs) such as email addresses, phone numbers, - Using 3rd Party IDs (3PIDs) such as email addresses, phone numbers,
@ -82,9 +82,6 @@ at the `Matrix spec <https://matrix.org/docs/spec>`_, and experiment with the
Thanks for using Matrix! Thanks for using Matrix!
[1] End-to-end encryption is currently in beta: `blog post <https://matrix.org/blog/2016/11/21/matrixs-olm-end-to-end-encryption-security-assessment-released-and-implemented-cross-platform-on-riot-at-last>`_.
Support Support
======= =======
@ -115,12 +112,11 @@ Unless you are running a test instance of Synapse on your local machine, in
general, you will need to enable TLS support before you can successfully general, you will need to enable TLS support before you can successfully
connect from a client: see `<INSTALL.md#tls-certificates>`_. connect from a client: see `<INSTALL.md#tls-certificates>`_.
An easy way to get started is to login or register via Riot at An easy way to get started is to login or register via Element at
https://riot.im/app/#/login or https://riot.im/app/#/register respectively. https://app.element.io/#/login or https://app.element.io/#/register respectively.
You will need to change the server you are logging into from ``matrix.org`` You will need to change the server you are logging into from ``matrix.org``
and instead specify a Homeserver URL of ``https://<server_name>:8448`` and instead specify a Homeserver URL of ``https://<server_name>:8448``
(or just ``https://<server_name>`` if you are using a reverse proxy). (or just ``https://<server_name>`` if you are using a reverse proxy).
(Leave the identity server as the default - see `Identity servers`_.)
If you prefer to use another client, refer to our If you prefer to use another client, refer to our
`client breakdown <https://matrix.org/docs/projects/clients-matrix>`_. `client breakdown <https://matrix.org/docs/projects/clients-matrix>`_.
@ -137,7 +133,7 @@ it, specify ``enable_registration: true`` in ``homeserver.yaml``. (It is then
recommended to also set up CAPTCHA - see `<docs/CAPTCHA_SETUP.md>`_.) recommended to also set up CAPTCHA - see `<docs/CAPTCHA_SETUP.md>`_.)
Once ``enable_registration`` is set to ``true``, it is possible to register a Once ``enable_registration`` is set to ``true``, it is possible to register a
user via `riot.im <https://riot.im/app/#/register>`_ or other Matrix clients. user via a Matrix client.
Your new user name will be formed partly from the ``server_name``, and partly Your new user name will be formed partly from the ``server_name``, and partly
from a localpart you specify when you create the account. Your name will take from a localpart you specify when you create the account. Your name will take
@ -183,30 +179,6 @@ versions of synapse.
.. _UPGRADE.rst: UPGRADE.rst .. _UPGRADE.rst: UPGRADE.rst
Using PostgreSQL
================
Synapse offers two database engines:
* `PostgreSQL <https://www.postgresql.org>`_
* `SQLite <https://sqlite.org/>`_
Almost all installations should opt to use PostgreSQL. Advantages include:
* significant performance improvements due to the superior threading and
caching model, smarter query optimiser
* allowing the DB to be run on separate hardware
* allowing basic active/backup high-availability with a "hot spare" synapse
pointing at the same DB master, as well as enabling DB replication in
synapse itself.
For information on how to install and use PostgreSQL, please see
`docs/postgres.md <docs/postgres.md>`_.
By default Synapse uses SQLite and in doing so trades performance for convenience.
SQLite is only recommended in Synapse for testing purposes or for servers with
light workloads.
.. _reverse-proxy: .. _reverse-proxy:
Using a reverse proxy with Synapse Using a reverse proxy with Synapse
@ -255,10 +227,9 @@ email address.
Password reset Password reset
============== ==============
If a user has registered an email address to their account using an identity Users can reset their password through their client. Alternatively, a server admin
server, they can request a password-reset token via clients such as Riot. can reset a users password using the `admin API <docs/admin_api/user_admin_api.rst#reset-password>`_
or by directly editing the database as shown below.
A manual password reset can be done via direct database access as follows.
First calculate the hash of the new password:: First calculate the hash of the new password::

1
changelog.d/7736.feature Normal file
View file

@ -0,0 +1 @@
Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654).

1
changelog.d/7899.doc Normal file
View file

@ -0,0 +1 @@
Document how to set up a Client Well-Known file and fix several pieces of outdated documentation.

1
changelog.d/7902.feature Normal file
View file

@ -0,0 +1 @@
Add option to allow server admins to join rooms which fail complexity checks. Contributed by @lugino-emeritus.

1
changelog.d/7936.misc Normal file
View file

@ -0,0 +1 @@
Switch to the JSON implementation from the standard library and bump the minimum version of the canonicaljson library to 1.2.0.

1
changelog.d/7947.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/7948.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/7949.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/7951.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/7952.misc Normal file
View file

@ -0,0 +1 @@
Move some database-related log lines from the default logger to the database/transaction loggers.

1
changelog.d/7963.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/7964.feature Normal file
View file

@ -0,0 +1 @@
Add an option to purge room or not with delete room admin endpoint (`POST /_synapse/admin/v1/rooms/<room_id>/delete`). Contributed by @dklimpel.

1
changelog.d/7965.misc Normal file
View file

@ -0,0 +1 @@
Add a script to detect source code files using non-unix line terminators.

1
changelog.d/7970.misc Normal file
View file

@ -0,0 +1 @@
Add a script to detect source code files using non-unix line terminators.

1
changelog.d/7971.misc Normal file
View file

@ -0,0 +1 @@
Log the SAML session ID during creation.

1
changelog.d/7973.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/7975.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/7976.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/7978.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a long standing bug: 'Duplicate key value violates unique constraint "event_relations_id"' when message retention is configured.

1
changelog.d/7979.misc Normal file
View file

@ -0,0 +1 @@
Switch to the JSON implementation from the standard library and bump the minimum version of the canonicaljson library to 1.2.0.

1
changelog.d/7980.bugfix Normal file
View file

@ -0,0 +1 @@
Fix "no create event in auth events" when trying to reject invitation after inviter leaves. Bug introduced in Synapse v1.10.0.

1
changelog.d/7981.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/7990.doc Normal file
View file

@ -0,0 +1 @@
Improve workers docs.

1
changelog.d/7992.doc Normal file
View file

@ -0,0 +1 @@
Fix typo in `docs/workers.md`.

1
changelog.d/7998.doc Normal file
View file

@ -0,0 +1 @@
Add documentation for how to undo a room shutdown.

View file

@ -609,13 +609,15 @@ class SynapseCmd(cmd.Cmd):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_event_stream(self, timeout): def _do_event_stream(self, timeout):
res = yield self.http_client.get_json( res = yield defer.ensureDeferred(
self._url() + "/events", self.http_client.get_json(
{ self._url() + "/events",
"access_token": self._tok(), {
"timeout": str(timeout), "access_token": self._tok(),
"from": self.event_stream_token, "timeout": str(timeout),
}, "from": self.event_stream_token,
},
)
) )
print(json.dumps(res, indent=4)) print(json.dumps(res, indent=4))

10
debian/changelog vendored
View file

@ -1,3 +1,13 @@
matrix-synapse-py3 (1.xx.0) stable; urgency=medium
[ Synapse Packaging team ]
* New synapse release 1.xx.0.
[ Aaron Raimist ]
* Fix outdated documentation for SYNAPSE_CACHE_FACTOR
-- Synapse Packaging team <packages@matrix.org> XXXXX
matrix-synapse-py3 (1.18.0) stable; urgency=medium matrix-synapse-py3 (1.18.0) stable; urgency=medium
* New synapse release 1.18.0. * New synapse release 1.18.0.

View file

@ -1,2 +1,2 @@
# Specify environment variables used when running Synapse # Specify environment variables used when running Synapse
# SYNAPSE_CACHE_FACTOR=1 (default) # SYNAPSE_CACHE_FACTOR=0.5 (default)

27
debian/synctl.ronn vendored
View file

@ -46,19 +46,20 @@ Configuration file may be generated as follows:
## ENVIRONMENT ## ENVIRONMENT
* `SYNAPSE_CACHE_FACTOR`: * `SYNAPSE_CACHE_FACTOR`:
Synapse's architecture is quite RAM hungry currently - a lot of Synapse's architecture is quite RAM hungry currently - we deliberately
recent room data and metadata is deliberately cached in RAM in cache a lot of recent room data and metadata in RAM in order to speed up
order to speed up common requests. This will be improved in common requests. We'll improve this in the future, but for now the easiest
future, but for now the easiest way to either reduce the RAM usage way to either reduce the RAM usage (at the risk of slowing things down)
(at the risk of slowing things down) is to set the is to set the almost-undocumented ``SYNAPSE_CACHE_FACTOR`` environment
SYNAPSE_CACHE_FACTOR environment variable. Roughly speaking, a variable. The default is 0.5, which can be decreased to reduce RAM usage
SYNAPSE_CACHE_FACTOR of 1.0 will max out at around 3-4GB of in memory constrained enviroments, or increased if performance starts to
resident memory - this is what we currently run the matrix.org degrade.
on. The default setting is currently 0.1, which is probably around
a ~700MB footprint. You can dial it down further to 0.02 if However, degraded performance due to a low cache factor, common on
desired, which targets roughly ~512MB. Conversely you can dial it machines with slow disks, often leads to explosions in memory use due
up if you need performance for lots of users and have a box with a backlogged requests. In this case, reducing the cache factor will make
lot of RAM. things worse. Instead, try increasing it drastically. 2.0 is a good
starting value.
## COPYRIGHT ## COPYRIGHT

View file

@ -10,5 +10,16 @@
# homeserver.yaml. Instead, if you are starting from scratch, please generate # homeserver.yaml. Instead, if you are starting from scratch, please generate
# a fresh config using Synapse by following the instructions in INSTALL.md. # a fresh config using Synapse by following the instructions in INSTALL.md.
# Configuration options that take a time period can be set using a number
# followed by a letter. Letters have the following meanings:
# s = second
# m = minute
# h = hour
# d = day
# w = week
# y = year
# For example, setting redaction_retention_period: 5m would remove redacted
# messages from the database after 5 minutes, rather than 5 months.
################################################################################ ################################################################################

View file

@ -369,7 +369,9 @@ to the new room will have power level `-10` by default, and thus be unable to sp
If `block` is `True` it prevents new joins to the old room. If `block` is `True` it prevents new joins to the old room.
This API will remove all trace of the old room from your database after removing This API will remove all trace of the old room from your database after removing
all local users. all local users. If `purge` is `true` (the default), all traces of the old room will
be removed from your database after removing all local users. If you do not want
this to happen, set `purge` to `false`.
Depending on the amount of history being purged a call to the API may take Depending on the amount of history being purged a call to the API may take
several minutes or longer. several minutes or longer.
@ -388,7 +390,8 @@ with a body of:
"new_room_user_id": "@someuser:example.com", "new_room_user_id": "@someuser:example.com",
"room_name": "Content Violation Notification", "room_name": "Content Violation Notification",
"message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service.", "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service.",
"block": true "block": true,
"purge": true
} }
``` ```
@ -430,8 +433,10 @@ The following JSON body parameters are available:
`new_room_user_id` in the new room. Ideally this will clearly convey why the `new_room_user_id` in the new room. Ideally this will clearly convey why the
original room was shut down. Defaults to `Sharing illegal content on this server original room was shut down. Defaults to `Sharing illegal content on this server
is not permitted and rooms in violation will be blocked.` is not permitted and rooms in violation will be blocked.`
* `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing future attempts to * `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing
join the room. Defaults to `false`. future attempts to join the room. Defaults to `false`.
* `purge` - Optional. If set to `true`, it will remove all traces of the room from your database.
Defaults to `true`.
The JSON body must not be empty. The body must be at least `{}`. The JSON body must not be empty. The body must be at least `{}`.

View file

@ -33,7 +33,7 @@ You will need to authenticate with an access token for an admin user.
* `message` - Optional. A string containing the first message that will be sent as * `message` - Optional. A string containing the first message that will be sent as
`new_room_user_id` in the new room. Ideally this will clearly convey why the `new_room_user_id` in the new room. Ideally this will clearly convey why the
original room was shut down. original room was shut down.
If not specified, the default value of `room_name` is "Content Violation If not specified, the default value of `room_name` is "Content Violation
Notification". The default value of `message` is "Sharing illegal content on Notification". The default value of `message` is "Sharing illegal content on
othis server is not permitted and rooms in violation will be blocked." othis server is not permitted and rooms in violation will be blocked."
@ -72,3 +72,23 @@ Response:
"new_room_id": "!newroomid:example.com", "new_room_id": "!newroomid:example.com",
}, },
``` ```
## Undoing room shutdowns
*Note*: This guide may be outdated by the time you read it. By nature of room shutdowns being performed at the database level,
the structure can and does change without notice.
First, it's important to understand that a room shutdown is very destructive. Undoing a shutdown is not as simple as pretending it
never happened - work has to be done to move forward instead of resetting the past.
1. For safety reasons, it is recommended to shut down Synapse prior to continuing.
2. In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';`
* For caution: it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got 1 result, then `COMMIT;`.
* The room ID is the same one supplied to the shutdown room API, not the Content Violation room.
3. Restart Synapse (required).
You will have to manually handle, if you so choose, the following:
* Aliases that would have been redirected to the Content Violation room.
* Users that would have been booted from the room (and will have been force-joined to the Content Violation room).
* Removal of the Content Violation room if desired.

View file

@ -27,7 +27,7 @@
different thread to Synapse. This can make it more resilient to different thread to Synapse. This can make it more resilient to
heavy load meaning metrics cannot be retrieved, and can be exposed heavy load meaning metrics cannot be retrieved, and can be exposed
to just internal networks easier. The served metrics are available to just internal networks easier. The served metrics are available
over HTTP only, and will be available at `/`. over HTTP only, and will be available at `/_synapse/metrics`.
Add a new listener to homeserver.yaml: Add a new listener to homeserver.yaml:

View file

@ -188,6 +188,9 @@ to do step 2.
It is safe to at any time kill the port script and restart it. It is safe to at any time kill the port script and restart it.
Note that the database may take up significantly more (25% - 100% more)
space on disk after porting to Postgres.
### Using the port script ### Using the port script
Firstly, shut down the currently running synapse server and copy its Firstly, shut down the currently running synapse server and copy its

View file

@ -10,6 +10,17 @@
# homeserver.yaml. Instead, if you are starting from scratch, please generate # homeserver.yaml. Instead, if you are starting from scratch, please generate
# a fresh config using Synapse by following the instructions in INSTALL.md. # a fresh config using Synapse by following the instructions in INSTALL.md.
# Configuration options that take a time period can be set using a number
# followed by a letter. Letters have the following meanings:
# s = second
# m = minute
# h = hour
# d = day
# w = week
# y = year
# For example, setting redaction_retention_period: 5m would remove redacted
# messages from the database after 5 minutes, rather than 5 months.
################################################################################ ################################################################################
# Configuration file for Synapse. # Configuration file for Synapse.
@ -314,6 +325,10 @@ limit_remote_rooms:
# #
#complexity_error: "This room is too complex." #complexity_error: "This room is too complex."
# allow server admins to join complex rooms. Default is false.
#
#admins_can_join: true
# Whether to require a user to be in the room to add an alias to it. # Whether to require a user to be in the room to add an alias to it.
# Defaults to 'true'. # Defaults to 'true'.
# #
@ -1157,24 +1172,6 @@ account_validity:
# #
#default_identity_server: https://matrix.org #default_identity_server: https://matrix.org
# The list of identity servers trusted to verify third party
# identifiers by this server.
#
# Also defines the ID server which will be called when an account is
# deactivated (one will be picked arbitrarily).
#
# Note: This option is deprecated. Since v0.99.4, Synapse has tracked which identity
# server a 3PID has been bound to. For 3PIDs bound before then, Synapse runs a
# background migration script, informing itself that the identity server all of its
# 3PIDs have been bound to is likely one of the below.
#
# As of Synapse v1.4.0, all other functionality of this option has been deprecated, and
# it is now solely used for the purposes of the background migration script, and can be
# removed once it has run.
#trusted_third_party_id_servers:
# - matrix.org
# - vector.im
# Handle threepid (email/phone etc) registration and password resets through a set of # Handle threepid (email/phone etc) registration and password resets through a set of
# *trusted* identity servers. Note that this allows the configured identity server to # *trusted* identity servers. Note that this allows the configured identity server to
# reset passwords for accounts! # reset passwords for accounts!

View file

@ -1,10 +1,10 @@
# Scaling synapse via workers # Scaling synapse via workers
For small instances it recommended to run Synapse in monolith mode (the For small instances it recommended to run Synapse in the default monolith mode.
default). For larger instances where performance is a concern it can be helpful For larger instances where performance is a concern it can be helpful to split
to split out functionality into multiple separate python processes. These out functionality into multiple separate python processes. These processes are
processes are called 'workers', and are (eventually) intended to scale called 'workers', and are (eventually) intended to scale horizontally
horizontally independently. independently.
Synapse's worker support is under active development and subject to change as Synapse's worker support is under active development and subject to change as
we attempt to rapidly scale ever larger Synapse instances. However we are we attempt to rapidly scale ever larger Synapse instances. However we are
@ -23,29 +23,30 @@ The processes communicate with each other via a Synapse-specific protocol called
feeds streams of newly written data between processes so they can be kept in feeds streams of newly written data between processes so they can be kept in
sync with the database state. sync with the database state.
Additionally, processes may make HTTP requests to each other. Typically this is When configured to do so, Synapse uses a
used for operations which need to wait for a reply - such as sending an event. [Redis pub/sub channel](https://redis.io/topics/pubsub) to send the replication
stream between all configured Synapse processes. Additionally, processes may
make HTTP requests to each other, primarily for operations which need to wait
for a reply ─ such as sending an event.
As of Synapse v1.13.0, it is possible to configure Synapse to send replication Redis support was added in v1.13.0 with it becoming the recommended method in
via a [Redis pub/sub channel](https://redis.io/topics/pubsub), and is now the v1.18.0. It replaced the old direct TCP connections (which is deprecated as of
recommended way of configuring replication. This is an alternative to the old v1.18.0) to the main process. With Redis, rather than all the workers connecting
direct TCP connections to the main process: rather than all the workers to the main process, all the workers and the main process connect to Redis,
connecting to the main process, all the workers and the main process connect to which relays replication commands between processes. This can give a significant
Redis, which relays replication commands between processes. This can give a cpu saving on the main process and will be a prerequisite for upcoming
significant cpu saving on the main process and will be a prerequisite for performance improvements.
upcoming performance improvements.
(See the [Architectural diagram](#architectural-diagram) section at the end for See the [Architectural diagram](#architectural-diagram) section at the end for
a visualisation of what this looks like) a visualisation of what this looks like.
## Setting up workers ## Setting up workers
A Redis server is required to manage the communication between the processes. A Redis server is required to manage the communication between the processes.
(The older direct TCP connections are now deprecated.) The Redis server The Redis server should be installed following the normal procedure for your
should be installed following the normal procedure for your distribution (e.g. distribution (e.g. `apt install redis-server` on Debian). It is safe to use an
`apt install redis-server` on Debian). It is safe to use an existing Redis existing Redis deployment if you have one.
deployment if you have one.
Once installed, check that Redis is running and accessible from the host running Once installed, check that Redis is running and accessible from the host running
Synapse, for example by executing `echo PING | nc -q1 localhost 6379` and seeing Synapse, for example by executing `echo PING | nc -q1 localhost 6379` and seeing
@ -65,8 +66,9 @@ https://hub.docker.com/r/matrixdotorg/synapse/.
To make effective use of the workers, you will need to configure an HTTP To make effective use of the workers, you will need to configure an HTTP
reverse-proxy such as nginx or haproxy, which will direct incoming requests to reverse-proxy such as nginx or haproxy, which will direct incoming requests to
the correct worker, or to the main synapse instance. See [reverse_proxy.md](reverse_proxy.md) the correct worker, or to the main synapse instance. See
for information on setting up a reverse proxy. [reverse_proxy.md](reverse_proxy.md) for information on setting up a reverse
proxy.
To enable workers you should create a configuration file for each worker To enable workers you should create a configuration file for each worker
process. Each worker configuration file inherits the configuration of the shared process. Each worker configuration file inherits the configuration of the shared
@ -75,8 +77,12 @@ that worker, e.g. the HTTP listener that it provides (if any); logging
configuration; etc. You should minimise the number of overrides though to configuration; etc. You should minimise the number of overrides though to
maintain a usable config. maintain a usable config.
Next you need to add both a HTTP replication listener and redis config to the
shared Synapse configuration file (`homeserver.yaml`). For example: ### Shared Configuration
Next you need to add both a HTTP replication listener, used for HTTP requests
between processes, and redis config to the shared Synapse configuration file
(`homeserver.yaml`). For example:
```yaml ```yaml
# extend the existing `listeners` section. This defines the ports that the # extend the existing `listeners` section. This defines the ports that the
@ -98,6 +104,9 @@ See the sample config for the full documentation of each option.
Under **no circumstances** should the replication listener be exposed to the Under **no circumstances** should the replication listener be exposed to the
public internet; it has no authentication and is unencrypted. public internet; it has no authentication and is unencrypted.
### Worker Configuration
In the config file for each worker, you must specify the type of worker In the config file for each worker, you must specify the type of worker
application (`worker_app`), and you should specify a unqiue name for the worker application (`worker_app`), and you should specify a unqiue name for the worker
(`worker_name`). The currently available worker applications are listed below. (`worker_name`). The currently available worker applications are listed below.
@ -278,7 +287,7 @@ instance_map:
host: localhost host: localhost
port: 8034 port: 8034
streams_writers: stream_writers:
events: event_persister1 events: event_persister1
``` ```

View file

@ -0,0 +1,34 @@
#!/bin/bash
#
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This script checks that line terminators in all repository files (excluding
# those in the .git directory) feature unix line terminators.
#
# Usage:
#
# ./check_line_terminators.sh
#
# The script will emit exit code 1 if any files that do not use unix line
# terminators are found, 0 otherwise.
# cd to the root of the repository
cd `dirname $0`/..
# Find and print files with non-unix line terminators
if find . -path './.git/*' -prune -o -type f -print0 | xargs -0 grep -I -l $'\r$'; then
echo -e '\e[31mERROR: found files with CRLF line endings. See above.\e[39m'
exit 1
fi

View file

@ -69,7 +69,7 @@ logger = logging.getLogger("synapse_port_db")
BOOLEAN_COLUMNS = { BOOLEAN_COLUMNS = {
"events": ["processed", "outlier", "contains_url"], "events": ["processed", "outlier", "contains_url", "count_as_unread"],
"rooms": ["is_public"], "rooms": ["is_public"],
"event_edges": ["is_state"], "event_edges": ["is_state"],
"presence_list": ["accepted"], "presence_list": ["accepted"],

View file

@ -17,6 +17,7 @@
""" This is a reference implementation of a Matrix homeserver. """ This is a reference implementation of a Matrix homeserver.
""" """
import json
import os import os
import sys import sys
@ -25,6 +26,9 @@ if sys.version_info < (3, 5):
print("Synapse requires Python 3.5 or above.") print("Synapse requires Python 3.5 or above.")
sys.exit(1) sys.exit(1)
# Twisted and canonicaljson will fail to import when this file is executed to
# get the __version__ during a fresh install. That's OK and subsequent calls to
# actually start Synapse will import these libraries fine.
try: try:
from twisted.internet import protocol from twisted.internet import protocol
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
@ -36,6 +40,14 @@ try:
except ImportError: except ImportError:
pass pass
# Use the standard library json implementation instead of simplejson.
try:
from canonicaljson import set_json_library
set_json_library(json)
except ImportError:
pass
__version__ = "1.18.0" __version__ = "1.18.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):

View file

@ -82,7 +82,7 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_from_context(self, room_version: str, event, context, do_sig_check=True): def check_from_context(self, room_version: str, event, context, do_sig_check=True):
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
auth_events_ids = yield self.compute_auth_events( auth_events_ids = yield self.compute_auth_events(
event, prev_state_ids, for_verification=True event, prev_state_ids, for_verification=True
) )

View file

@ -628,7 +628,7 @@ class GenericWorkerServer(HomeServer):
self.get_tcp_replication().start_replication(self) self.get_tcp_replication().start_replication(self)
def remove_pusher(self, app_id, push_key, user_id): async def remove_pusher(self, app_id, push_key, user_id):
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
def build_replication_data_handler(self): def build_replication_data_handler(self):

View file

@ -15,11 +15,9 @@
import logging import logging
import re import re
from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import GroupID, get_domain_from_id from synapse.types import GroupID, get_domain_from_id
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -43,7 +41,7 @@ class AppServiceTransaction(object):
Args: Args:
as_api(ApplicationServiceApi): The API to use to send. as_api(ApplicationServiceApi): The API to use to send.
Returns: Returns:
A Deferred which resolves to True if the transaction was sent. An Awaitable which resolves to True if the transaction was sent.
""" """
return as_api.push_bulk( return as_api.push_bulk(
service=self.service, events=self.events, txn_id=self.id service=self.service, events=self.events, txn_id=self.id
@ -172,8 +170,7 @@ class ApplicationService(object):
return regex_obj["exclusive"] return regex_obj["exclusive"]
return False return False
@defer.inlineCallbacks async def _matches_user(self, event, store):
def _matches_user(self, event, store):
if not event: if not event:
return False return False
@ -188,12 +185,12 @@ class ApplicationService(object):
if not store: if not store:
return False return False
does_match = yield self._matches_user_in_member_list(event.room_id, store) does_match = await self._matches_user_in_member_list(event.room_id, store)
return does_match return does_match
@cachedInlineCallbacks(num_args=1, cache_context=True) @cached(num_args=1, cache_context=True)
def _matches_user_in_member_list(self, room_id, store, cache_context): async def _matches_user_in_member_list(self, room_id, store, cache_context):
member_list = yield store.get_users_in_room( member_list = await store.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate room_id, on_invalidate=cache_context.invalidate
) )
@ -208,35 +205,33 @@ class ApplicationService(object):
return self.is_interested_in_room(event.room_id) return self.is_interested_in_room(event.room_id)
return False return False
@defer.inlineCallbacks async def _matches_aliases(self, event, store):
def _matches_aliases(self, event, store):
if not store or not event: if not store or not event:
return False return False
alias_list = yield store.get_aliases_for_room(event.room_id) alias_list = await store.get_aliases_for_room(event.room_id)
for alias in alias_list: for alias in alias_list:
if self.is_interested_in_alias(alias): if self.is_interested_in_alias(alias):
return True return True
return False return False
@defer.inlineCallbacks async def is_interested(self, event, store=None) -> bool:
def is_interested(self, event, store=None):
"""Check if this service is interested in this event. """Check if this service is interested in this event.
Args: Args:
event(Event): The event to check. event(Event): The event to check.
store(DataStore) store(DataStore)
Returns: Returns:
bool: True if this service would like to know about this event. True if this service would like to know about this event.
""" """
# Do cheap checks first # Do cheap checks first
if self._matches_room_id(event): if self._matches_room_id(event):
return True return True
if (yield self._matches_aliases(event, store)): if await self._matches_aliases(event, store):
return True return True
if (yield self._matches_user(event, store)): if await self._matches_user(event, store):
return True return True
return False return False

View file

@ -93,13 +93,12 @@ class ApplicationServiceApi(SimpleHttpClient):
hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
) )
@defer.inlineCallbacks async def query_user(self, service, user_id):
def query_user(self, service, user_id):
if service.url is None: if service.url is None:
return False return False
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
try: try:
response = yield self.get_json(uri, {"access_token": service.hs_token}) response = await self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object if response is not None: # just an empty json object
return True return True
except CodeMessageException as e: except CodeMessageException as e:
@ -110,14 +109,12 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_user to %s threw exception %s", uri, ex) logger.warning("query_user to %s threw exception %s", uri, ex)
return False return False
@defer.inlineCallbacks async def query_alias(self, service, alias):
def query_alias(self, service, alias):
if service.url is None: if service.url is None:
return False return False
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias)) uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
response = None
try: try:
response = yield self.get_json(uri, {"access_token": service.hs_token}) response = await self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object if response is not None: # just an empty json object
return True return True
except CodeMessageException as e: except CodeMessageException as e:
@ -128,8 +125,7 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_alias to %s threw exception %s", uri, ex) logger.warning("query_alias to %s threw exception %s", uri, ex)
return False return False
@defer.inlineCallbacks async def query_3pe(self, service, kind, protocol, fields):
def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER: if kind == ThirdPartyEntityKind.USER:
required_field = "userid" required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION: elif kind == ThirdPartyEntityKind.LOCATION:
@ -146,7 +142,7 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.parse.quote(protocol), urllib.parse.quote(protocol),
) )
try: try:
response = yield self.get_json(uri, fields) response = await self.get_json(uri, fields)
if not isinstance(response, list): if not isinstance(response, list):
logger.warning( logger.warning(
"query_3pe to %s returned an invalid response %r", uri, response "query_3pe to %s returned an invalid response %r", uri, response
@ -202,8 +198,7 @@ class ApplicationServiceApi(SimpleHttpClient):
key = (service.id, protocol) key = (service.id, protocol)
return self.protocol_meta_cache.wrap(key, _get) return self.protocol_meta_cache.wrap(key, _get)
@defer.inlineCallbacks async def push_bulk(self, service, events, txn_id=None):
def push_bulk(self, service, events, txn_id=None):
if service.url is None: if service.url is None:
return True return True
@ -218,7 +213,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id)) uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
try: try:
yield self.put_json( await self.put_json(
uri=uri, uri=uri,
json_body={"events": events}, json_body={"events": events},
args={"access_token": service.hs_token}, args={"access_token": service.hs_token},

View file

@ -50,8 +50,6 @@ components.
""" """
import logging import logging
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState from synapse.appservice import ApplicationServiceState
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
@ -73,12 +71,11 @@ class ApplicationServiceScheduler(object):
self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
@defer.inlineCallbacks async def start(self):
def start(self):
logger.info("Starting appservice scheduler") logger.info("Starting appservice scheduler")
# check for any DOWN ASes and start recoverers for them. # check for any DOWN ASes and start recoverers for them.
services = yield self.store.get_appservices_by_state( services = await self.store.get_appservices_by_state(
ApplicationServiceState.DOWN ApplicationServiceState.DOWN
) )
@ -117,8 +114,7 @@ class _ServiceQueuer(object):
"as-sender-%s" % (service.id,), self._send_request, service "as-sender-%s" % (service.id,), self._send_request, service
) )
@defer.inlineCallbacks async def _send_request(self, service):
def _send_request(self, service):
# sanity-check: we shouldn't get here if this service already has a sender # sanity-check: we shouldn't get here if this service already has a sender
# running. # running.
assert service.id not in self.requests_in_flight assert service.id not in self.requests_in_flight
@ -130,7 +126,7 @@ class _ServiceQueuer(object):
if not events: if not events:
return return
try: try:
yield self.txn_ctrl.send(service, events) await self.txn_ctrl.send(service, events)
except Exception: except Exception:
logger.exception("AS request failed") logger.exception("AS request failed")
finally: finally:
@ -162,36 +158,33 @@ class _TransactionController(object):
# for UTs # for UTs
self.RECOVERER_CLASS = _Recoverer self.RECOVERER_CLASS = _Recoverer
@defer.inlineCallbacks async def send(self, service, events):
def send(self, service, events):
try: try:
txn = yield self.store.create_appservice_txn(service=service, events=events) txn = await self.store.create_appservice_txn(service=service, events=events)
service_is_up = yield self._is_service_up(service) service_is_up = await self._is_service_up(service)
if service_is_up: if service_is_up:
sent = yield txn.send(self.as_api) sent = await txn.send(self.as_api)
if sent: if sent:
yield txn.complete(self.store) await txn.complete(self.store)
else: else:
run_in_background(self._on_txn_fail, service) run_in_background(self._on_txn_fail, service)
except Exception: except Exception:
logger.exception("Error creating appservice transaction") logger.exception("Error creating appservice transaction")
run_in_background(self._on_txn_fail, service) run_in_background(self._on_txn_fail, service)
@defer.inlineCallbacks async def on_recovered(self, recoverer):
def on_recovered(self, recoverer):
logger.info( logger.info(
"Successfully recovered application service AS ID %s", recoverer.service.id "Successfully recovered application service AS ID %s", recoverer.service.id
) )
self.recoverers.pop(recoverer.service.id) self.recoverers.pop(recoverer.service.id)
logger.info("Remaining active recoverers: %s", len(self.recoverers)) logger.info("Remaining active recoverers: %s", len(self.recoverers))
yield self.store.set_appservice_state( await self.store.set_appservice_state(
recoverer.service, ApplicationServiceState.UP recoverer.service, ApplicationServiceState.UP
) )
@defer.inlineCallbacks async def _on_txn_fail(self, service):
def _on_txn_fail(self, service):
try: try:
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) await self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
self.start_recoverer(service) self.start_recoverer(service)
except Exception: except Exception:
logger.exception("Error starting AS recoverer") logger.exception("Error starting AS recoverer")
@ -211,9 +204,8 @@ class _TransactionController(object):
recoverer.recover() recoverer.recover()
logger.info("Now %i active recoverers", len(self.recoverers)) logger.info("Now %i active recoverers", len(self.recoverers))
@defer.inlineCallbacks async def _is_service_up(self, service):
def _is_service_up(self, service): state = await self.store.get_appservice_state(service)
state = yield self.store.get_appservice_state(service)
return state == ApplicationServiceState.UP or state is None return state == ApplicationServiceState.UP or state is None
@ -254,25 +246,24 @@ class _Recoverer(object):
self.backoff_counter += 1 self.backoff_counter += 1
self.recover() self.recover()
@defer.inlineCallbacks async def retry(self):
def retry(self):
logger.info("Starting retries on %s", self.service.id) logger.info("Starting retries on %s", self.service.id)
try: try:
while True: while True:
txn = yield self.store.get_oldest_unsent_txn(self.service) txn = await self.store.get_oldest_unsent_txn(self.service)
if not txn: if not txn:
# nothing left: we're done! # nothing left: we're done!
self.callback(self) await self.callback(self)
return return
logger.info( logger.info(
"Retrying transaction %s for AS ID %s", txn.id, txn.service.id "Retrying transaction %s for AS ID %s", txn.id, txn.service.id
) )
sent = yield txn.send(self.as_api) sent = await txn.send(self.as_api)
if not sent: if not sent:
break break
yield txn.complete(self.store) await txn.complete(self.store)
# reset the backoff counter and then process the next transaction # reset the backoff counter and then process the next transaction
self.backoff_counter = 1 self.backoff_counter = 1

View file

@ -333,24 +333,6 @@ class RegistrationConfig(Config):
# #
#default_identity_server: https://matrix.org #default_identity_server: https://matrix.org
# The list of identity servers trusted to verify third party
# identifiers by this server.
#
# Also defines the ID server which will be called when an account is
# deactivated (one will be picked arbitrarily).
#
# Note: This option is deprecated. Since v0.99.4, Synapse has tracked which identity
# server a 3PID has been bound to. For 3PIDs bound before then, Synapse runs a
# background migration script, informing itself that the identity server all of its
# 3PIDs have been bound to is likely one of the below.
#
# As of Synapse v1.4.0, all other functionality of this option has been deprecated, and
# it is now solely used for the purposes of the background migration script, and can be
# removed once it has run.
#trusted_third_party_id_servers:
# - matrix.org
# - vector.im
# Handle threepid (email/phone etc) registration and password resets through a set of # Handle threepid (email/phone etc) registration and password resets through a set of
# *trusted* identity servers. Note that this allows the configured identity server to # *trusted* identity servers. Note that this allows the configured identity server to
# reset passwords for accounts! # reset passwords for accounts!

View file

@ -439,6 +439,9 @@ class ServerConfig(Config):
validator=attr.validators.instance_of(str), validator=attr.validators.instance_of(str),
default=ROOM_COMPLEXITY_TOO_GREAT, default=ROOM_COMPLEXITY_TOO_GREAT,
) )
admins_can_join = attr.ib(
validator=attr.validators.instance_of(bool), default=False
)
self.limit_remote_rooms = LimitRemoteRoomsConfig( self.limit_remote_rooms = LimitRemoteRoomsConfig(
**(config.get("limit_remote_rooms") or {}) **(config.get("limit_remote_rooms") or {})
@ -893,6 +896,10 @@ class ServerConfig(Config):
# #
#complexity_error: "This room is too complex." #complexity_error: "This room is too complex."
# allow server admins to join complex rooms. Default is false.
#
#admins_can_join: true
# Whether to require a user to be in the room to add an alias to it. # Whether to require a user to be in the room to add an alias to it.
# Defaults to 'true'. # Defaults to 'true'.
# #

View file

@ -632,18 +632,20 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
) )
try: try:
query_response = yield self.client.post_json( query_response = yield defer.ensureDeferred(
destination=perspective_name, self.client.post_json(
path="/_matrix/key/v2/query", destination=perspective_name,
data={ path="/_matrix/key/v2/query",
"server_keys": { data={
server_name: { "server_keys": {
key_id: {"minimum_valid_until_ts": min_valid_ts} server_name: {
for key_id, min_valid_ts in server_keys.items() key_id: {"minimum_valid_until_ts": min_valid_ts}
for key_id, min_valid_ts in server_keys.items()
}
for server_name, server_keys in keys_to_fetch.items()
} }
for server_name, server_keys in keys_to_fetch.items() },
} )
},
) )
except (NotRetryingDestination, RequestSendFailed) as e: except (NotRetryingDestination, RequestSendFailed) as e:
# these both have str() representations which we can't really improve upon # these both have str() representations which we can't really improve upon
@ -792,23 +794,25 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
try: try:
response = yield self.client.get_json( response = yield defer.ensureDeferred(
destination=server_name, self.client.get_json(
path="/_matrix/key/v2/server/" destination=server_name,
+ urllib.parse.quote(requested_key_id), path="/_matrix/key/v2/server/"
ignore_backoff=True, + urllib.parse.quote(requested_key_id),
# we only give the remote server 10s to respond. It should be an ignore_backoff=True,
# easy request to handle, so if it doesn't reply within 10s, it's # we only give the remote server 10s to respond. It should be an
# probably not going to. # easy request to handle, so if it doesn't reply within 10s, it's
# # probably not going to.
# Furthermore, when we are acting as a notary server, we cannot #
# wait all day for all of the origin servers, as the requesting # Furthermore, when we are acting as a notary server, we cannot
# server will otherwise time out before we can respond. # wait all day for all of the origin servers, as the requesting
# # server will otherwise time out before we can respond.
# (Note that get_json may make 4 attempts, so this can still take #
# almost 45 seconds to fetch the headers, plus up to another 60s to # (Note that get_json may make 4 attempts, so this can still take
# read the response). # almost 45 seconds to fetch the headers, plus up to another 60s to
timeout=10000, # read the response).
timeout=10000,
)
) )
except (NotRetryingDestination, RequestSendFailed) as e: except (NotRetryingDestination, RequestSendFailed) as e:
# these both have str() representations which we can't really improve # these both have str() representations which we can't really improve

View file

@ -17,8 +17,6 @@ from typing import Optional
import attr import attr
from nacl.signing import SigningKey from nacl.signing import SigningKey
from twisted.internet import defer
from synapse.api.constants import MAX_DEPTH from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.errors import UnsupportedRoomVersionError
from synapse.api.room_versions import ( from synapse.api.room_versions import (
@ -95,31 +93,30 @@ class EventBuilder(object):
def is_state(self): def is_state(self):
return self._state_key is not None return self._state_key is not None
@defer.inlineCallbacks async def build(self, prev_event_ids):
def build(self, prev_event_ids):
"""Transform into a fully signed and hashed event """Transform into a fully signed and hashed event
Args: Args:
prev_event_ids (list[str]): The event IDs to use as the prev events prev_event_ids (list[str]): The event IDs to use as the prev events
Returns: Returns:
Deferred[FrozenEvent] FrozenEvent
""" """
state_ids = yield defer.ensureDeferred( state_ids = await self._state.get_current_state_ids(
self._state.get_current_state_ids(self.room_id, prev_event_ids) self.room_id, prev_event_ids
) )
auth_ids = yield self._auth.compute_auth_events(self, state_ids) auth_ids = await self._auth.compute_auth_events(self, state_ids)
format_version = self.room_version.event_format format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1: if format_version == EventFormatVersions.V1:
auth_events = yield self._store.add_event_hashes(auth_ids) auth_events = await self._store.add_event_hashes(auth_ids)
prev_events = yield self._store.add_event_hashes(prev_event_ids) prev_events = await self._store.add_event_hashes(prev_event_ids)
else: else:
auth_events = auth_ids auth_events = auth_ids
prev_events = prev_event_ids prev_events = prev_event_ids
old_depth = yield self._store.get_max_depth_of(prev_event_ids) old_depth = await self._store.get_max_depth_of(prev_event_ids)
depth = old_depth + 1 depth = old_depth + 1
# we cap depth of generated events, to ensure that they are not # we cap depth of generated events, to ensure that they are not

View file

@ -12,17 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional, Union from typing import TYPE_CHECKING, Optional, Union
import attr import attr
from frozendict import frozendict from frozendict import frozendict
from twisted.internet import defer
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import StateMap from synapse.types import StateMap
if TYPE_CHECKING:
from synapse.storage.data_stores.main import DataStore
@attr.s(slots=True) @attr.s(slots=True)
class EventContext: class EventContext:
@ -129,8 +131,7 @@ class EventContext:
delta_ids=delta_ids, delta_ids=delta_ids,
) )
@defer.inlineCallbacks async def serialize(self, event: EventBase, store: "DataStore") -> dict:
def serialize(self, event, store):
"""Converts self to a type that can be serialized as JSON, and then """Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize` deserialized by `deserialize`
@ -146,7 +147,7 @@ class EventContext:
# the prev_state_ids, so if we're a state event we include the event # the prev_state_ids, so if we're a state event we include the event
# id that we replaced in the state. # id that we replaced in the state.
if event.is_state(): if event.is_state():
prev_state_ids = yield self.get_prev_state_ids() prev_state_ids = await self.get_prev_state_ids()
prev_state_id = prev_state_ids.get((event.type, event.state_key)) prev_state_id = prev_state_ids.get((event.type, event.state_key))
else: else:
prev_state_id = None prev_state_id = None
@ -214,8 +215,7 @@ class EventContext:
return self._state_group return self._state_group
@defer.inlineCallbacks async def get_current_state_ids(self) -> Optional[StateMap[str]]:
def get_current_state_ids(self):
""" """
Gets the room state map, including this event - ie, the state in ``state_group`` Gets the room state map, including this event - ie, the state in ``state_group``
@ -224,32 +224,31 @@ class EventContext:
``rejected`` is set. ``rejected`` is set.
Returns: Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group Returns None if state_group is None, which happens when the associated
is None, which happens when the associated event is an outlier. event is an outlier.
Maps a (type, state_key) to the event ID of the state event matching Maps a (type, state_key) to the event ID of the state event matching
this tuple. this tuple.
""" """
if self.rejected: if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event") raise RuntimeError("Attempt to access state_ids of rejected event")
yield self._ensure_fetched() await self._ensure_fetched()
return self._current_state_ids return self._current_state_ids
@defer.inlineCallbacks async def get_prev_state_ids(self):
def get_prev_state_ids(self):
""" """
Gets the room state map, excluding this event. Gets the room state map, excluding this event.
For a non-state event, this will be the same as get_current_state_ids(). For a non-state event, this will be the same as get_current_state_ids().
Returns: Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group dict[(str, str), str]|None: Returns None if state_group
is None, which happens when the associated event is an outlier. is None, which happens when the associated event is an outlier.
Maps a (type, state_key) to the event ID of the state event matching Maps a (type, state_key) to the event ID of the state event matching
this tuple. this tuple.
""" """
yield self._ensure_fetched() await self._ensure_fetched()
return self._prev_state_ids return self._prev_state_ids
def get_cached_current_state_ids(self): def get_cached_current_state_ids(self):
@ -269,8 +268,8 @@ class EventContext:
return self._current_state_ids return self._current_state_ids
def _ensure_fetched(self): async def _ensure_fetched(self):
return defer.succeed(None) return None
@attr.s(slots=True) @attr.s(slots=True)
@ -303,21 +302,20 @@ class _AsyncEventContextImpl(EventContext):
_event_state_key = attr.ib(default=None) _event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None) _fetching_state_deferred = attr.ib(default=None)
def _ensure_fetched(self): async def _ensure_fetched(self):
if not self._fetching_state_deferred: if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(self._fill_out_state) self._fetching_state_deferred = run_in_background(self._fill_out_state)
return make_deferred_yieldable(self._fetching_state_deferred) return await make_deferred_yieldable(self._fetching_state_deferred)
@defer.inlineCallbacks async def _fill_out_state(self):
def _fill_out_state(self):
"""Called to populate the _current_state_ids and _prev_state_ids """Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database. attributes by loading from the database.
""" """
if self.state_group is None: if self.state_group is None:
return return
self._current_state_ids = yield self._storage.state.get_state_ids_for_group( self._current_state_ids = await self._storage.state.get_state_ids_for_group(
self.state_group self.state_group
) )
if self._event_state_key is not None: if self._event_state_key is not None:

View file

@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Requester
class ThirdPartyEventRules(object): class ThirdPartyEventRules(object):
@ -39,76 +41,79 @@ class ThirdPartyEventRules(object):
config=config, http_client=hs.get_simple_http_client() config=config, http_client=hs.get_simple_http_client()
) )
@defer.inlineCallbacks async def check_event_allowed(
def check_event_allowed(self, event, context): self, event: EventBase, context: EventContext
) -> bool:
"""Check if a provided event should be allowed in the given context. """Check if a provided event should be allowed in the given context.
Args: Args:
event (synapse.events.EventBase): The event to be checked. event: The event to be checked.
context (synapse.events.snapshot.EventContext): The context of the event. context: The context of the event.
Returns: Returns:
defer.Deferred[bool]: True if the event should be allowed, False if not. True if the event should be allowed, False if not.
""" """
if self.third_party_rules is None: if self.third_party_rules is None:
return True return True
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
# Retrieve the state events from the database. # Retrieve the state events from the database.
state_events = {} state_events = {}
for key, event_id in prev_state_ids.items(): for key, event_id in prev_state_ids.items():
state_events[key] = yield self.store.get_event(event_id, allow_none=True) state_events[key] = await self.store.get_event(event_id, allow_none=True)
ret = yield self.third_party_rules.check_event_allowed(event, state_events) ret = await self.third_party_rules.check_event_allowed(event, state_events)
return ret return ret
@defer.inlineCallbacks async def on_create_room(
def on_create_room(self, requester, config, is_requester_admin): self, requester: Requester, config: dict, is_requester_admin: bool
) -> bool:
"""Intercept requests to create room to allow, deny or update the """Intercept requests to create room to allow, deny or update the
request config. request config.
Args: Args:
requester (Requester) requester
config (dict): The creation config from the client. config: The creation config from the client.
is_requester_admin (bool): If the requester is an admin is_requester_admin: If the requester is an admin
Returns: Returns:
defer.Deferred[bool]: Whether room creation is allowed or denied. Whether room creation is allowed or denied.
""" """
if self.third_party_rules is None: if self.third_party_rules is None:
return True return True
ret = yield self.third_party_rules.on_create_room( ret = await self.third_party_rules.on_create_room(
requester, config, is_requester_admin requester, config, is_requester_admin
) )
return ret return ret
@defer.inlineCallbacks async def check_threepid_can_be_invited(
def check_threepid_can_be_invited(self, medium, address, room_id): self, medium: str, address: str, room_id: str
) -> bool:
"""Check if a provided 3PID can be invited in the given room. """Check if a provided 3PID can be invited in the given room.
Args: Args:
medium (str): The 3PID's medium. medium: The 3PID's medium.
address (str): The 3PID's address. address: The 3PID's address.
room_id (str): The room we want to invite the threepid to. room_id: The room we want to invite the threepid to.
Returns: Returns:
defer.Deferred[bool], True if the 3PID can be invited, False if not. True if the 3PID can be invited, False if not.
""" """
if self.third_party_rules is None: if self.third_party_rules is None:
return True return True
state_ids = yield self.store.get_filtered_current_state_ids(room_id) state_ids = await self.store.get_filtered_current_state_ids(room_id)
room_state_events = yield self.store.get_events(state_ids.values()) room_state_events = await self.store.get_events(state_ids.values())
state_events = {} state_events = {}
for key, event_id in state_ids.items(): for key, event_id in state_ids.items():
state_events[key] = room_state_events[event_id] state_events[key] = room_state_events[event_id]
ret = yield self.third_party_rules.check_threepid_can_be_invited( ret = await self.third_party_rules.check_threepid_can_be_invited(
medium, address, state_events medium, address, state_events
) )
return ret return ret

View file

@ -18,8 +18,6 @@ from typing import Any, Mapping, Union
from frozendict import frozendict from frozendict import frozendict
from twisted.internet import defer
from synapse.api.constants import EventTypes, RelationTypes from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion from synapse.api.room_versions import RoomVersion
@ -337,8 +335,9 @@ class EventClientSerializer(object):
hs.config.experimental_msc1849_support_enabled hs.config.experimental_msc1849_support_enabled
) )
@defer.inlineCallbacks async def serialize_event(
def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs): self, event, time_now, bundle_aggregations=True, **kwargs
):
"""Serializes a single event. """Serializes a single event.
Args: Args:
@ -348,7 +347,7 @@ class EventClientSerializer(object):
**kwargs: Arguments to pass to `serialize_event` **kwargs: Arguments to pass to `serialize_event`
Returns: Returns:
Deferred[dict]: The serialized event dict: The serialized event
""" """
# To handle the case of presence events and the like # To handle the case of presence events and the like
if not isinstance(event, EventBase): if not isinstance(event, EventBase):
@ -363,8 +362,8 @@ class EventClientSerializer(object):
if not event.internal_metadata.is_redacted() and ( if not event.internal_metadata.is_redacted() and (
self.experimental_msc1849_support_enabled and bundle_aggregations self.experimental_msc1849_support_enabled and bundle_aggregations
): ):
annotations = yield self.store.get_aggregation_groups_for_event(event_id) annotations = await self.store.get_aggregation_groups_for_event(event_id)
references = yield self.store.get_relations_for_event( references = await self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f" event_id, RelationTypes.REFERENCE, direction="f"
) )
@ -378,7 +377,7 @@ class EventClientSerializer(object):
edit = None edit = None
if event.type == EventTypes.Message: if event.type == EventTypes.Message:
edit = yield self.store.get_applicable_edit(event_id) edit = await self.store.get_applicable_edit(event_id)
if edit: if edit:
# If there is an edit replace the content, preserving existing # If there is an edit replace the content, preserving existing

View file

@ -135,7 +135,7 @@ class FederationClient(FederationBase):
and try the request anyway. and try the request anyway.
Returns: Returns:
a Deferred which will eventually yield a JSON object from the a Awaitable which will eventually yield a JSON object from the
response response
""" """
sent_queries_counter.labels(query_type).inc() sent_queries_counter.labels(query_type).inc()
@ -157,7 +157,7 @@ class FederationClient(FederationBase):
content (dict): The query content. content (dict): The query content.
Returns: Returns:
a Deferred which will eventually yield a JSON object from the an Awaitable which will eventually yield a JSON object from the
response response
""" """
sent_queries_counter.labels("client_device_keys").inc() sent_queries_counter.labels("client_device_keys").inc()
@ -180,7 +180,7 @@ class FederationClient(FederationBase):
content (dict): The query content. content (dict): The query content.
Returns: Returns:
a Deferred which will eventually yield a JSON object from the an Awaitable which will eventually yield a JSON object from the
response response
""" """
sent_queries_counter.labels("client_one_time_keys").inc() sent_queries_counter.labels("client_one_time_keys").inc()
@ -900,7 +900,7 @@ class FederationClient(FederationBase):
party instance party instance
Returns: Returns:
Deferred[Dict[str, Any]]: The response from the remote server, or None if Awaitable[Dict[str, Any]]: The response from the remote server, or None if
`remote_server` is the same as the local server_name `remote_server` is the same as the local server_name
Raises: Raises:

View file

@ -288,8 +288,7 @@ class FederationSender(object):
for destination in destinations: for destination in destinations:
self._get_per_destination_queue(destination).send_pdu(pdu, order) self._get_per_destination_queue(destination).send_pdu(pdu, order)
@defer.inlineCallbacks async def send_read_receipt(self, receipt: ReadReceipt) -> None:
def send_read_receipt(self, receipt: ReadReceipt):
"""Send a RR to any other servers in the room """Send a RR to any other servers in the room
Args: Args:
@ -330,9 +329,7 @@ class FederationSender(object):
room_id = receipt.room_id room_id = receipt.room_id
# Work out which remote servers should be poked and poke them. # Work out which remote servers should be poked and poke them.
domains = yield defer.ensureDeferred( domains = await self.state.get_current_hosts_in_room(room_id)
self.state.get_current_hosts_in_room(room_id)
)
domains = [ domains = [
d d
for d in domains for d in domains
@ -387,8 +384,7 @@ class FederationSender(object):
queue.flush_read_receipts_for_room(room_id) queue.flush_read_receipts_for_room(room_id)
@preserve_fn # the caller should not yield on this @preserve_fn # the caller should not yield on this
@defer.inlineCallbacks async def send_presence(self, states: List[UserPresenceState]):
def send_presence(self, states: List[UserPresenceState]):
"""Send the new presence states to the appropriate destinations. """Send the new presence states to the appropriate destinations.
This actually queues up the presence states ready for sending and This actually queues up the presence states ready for sending and
@ -423,7 +419,7 @@ class FederationSender(object):
if not states_map: if not states_map:
break break
yield self._process_presence_inner(list(states_map.values())) await self._process_presence_inner(list(states_map.values()))
except Exception: except Exception:
logger.exception("Error sending presence states to servers") logger.exception("Error sending presence states to servers")
finally: finally:
@ -450,14 +446,11 @@ class FederationSender(object):
self._get_per_destination_queue(destination).send_presence(states) self._get_per_destination_queue(destination).send_presence(states)
@measure_func("txnqueue._process_presence") @measure_func("txnqueue._process_presence")
@defer.inlineCallbacks async def _process_presence_inner(self, states: List[UserPresenceState]):
def _process_presence_inner(self, states: List[UserPresenceState]):
"""Given a list of states populate self.pending_presence_by_dest and """Given a list of states populate self.pending_presence_by_dest and
poke to send a new transaction to each destination poke to send a new transaction to each destination
""" """
hosts_and_states = yield defer.ensureDeferred( hosts_and_states = await get_interested_remotes(self.store, states, self.state)
get_interested_remotes(self.store, states, self.state)
)
for destinations, states in hosts_and_states: for destinations, states in hosts_and_states:
for destination in destinations: for destination in destinations:

View file

@ -18,8 +18,6 @@ import logging
import urllib import urllib
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.urls import ( from synapse.api.urls import (
@ -51,7 +49,7 @@ class TransportLayerClient(object):
event_id (str): The event we want the context at. event_id (str): The event we want the context at.
Returns: Returns:
Deferred: Results in a dict received from the remote homeserver. Awaitable: Results in a dict received from the remote homeserver.
""" """
logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id)
@ -75,7 +73,7 @@ class TransportLayerClient(object):
giving up. None indicates no timeout. giving up. None indicates no timeout.
Returns: Returns:
Deferred: Results in a dict received from the remote homeserver. Awaitable: Results in a dict received from the remote homeserver.
""" """
logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id)
@ -96,7 +94,7 @@ class TransportLayerClient(object):
limit (int) limit (int)
Returns: Returns:
Deferred: Results in a dict received from the remote homeserver. Awaitable: Results in a dict received from the remote homeserver.
""" """
logger.debug( logger.debug(
"backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s", "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s",
@ -118,16 +116,15 @@ class TransportLayerClient(object):
destination, path=path, args=args, try_trailing_slash_on_400=True destination, path=path, args=args, try_trailing_slash_on_400=True
) )
@defer.inlineCallbacks
@log_function @log_function
def send_transaction(self, transaction, json_data_callback=None): async def send_transaction(self, transaction, json_data_callback=None):
""" Sends the given Transaction to its destination """ Sends the given Transaction to its destination
Args: Args:
transaction (Transaction) transaction (Transaction)
Returns: Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. will be the decoded JSON body.
Fails with ``HTTPRequestException`` if we get an HTTP response Fails with ``HTTPRequestException`` if we get an HTTP response
@ -154,7 +151,7 @@ class TransportLayerClient(object):
path = _create_v1_path("/send/%s", transaction.transaction_id) path = _create_v1_path("/send/%s", transaction.transaction_id)
response = yield self.client.put_json( response = await self.client.put_json(
transaction.destination, transaction.destination,
path=path, path=path,
data=json_data, data=json_data,
@ -166,14 +163,13 @@ class TransportLayerClient(object):
return response return response
@defer.inlineCallbacks
@log_function @log_function
def make_query( async def make_query(
self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
): ):
path = _create_v1_path("/query/%s", query_type) path = _create_v1_path("/query/%s", query_type)
content = yield self.client.get_json( content = await self.client.get_json(
destination=destination, destination=destination,
path=path, path=path,
args=args, args=args,
@ -184,9 +180,10 @@ class TransportLayerClient(object):
return content return content
@defer.inlineCallbacks
@log_function @log_function
def make_membership_event(self, destination, room_id, user_id, membership, params): async def make_membership_event(
self, destination, room_id, user_id, membership, params
):
"""Asks a remote server to build and sign us a membership event """Asks a remote server to build and sign us a membership event
Note that this does not append any events to any graphs. Note that this does not append any events to any graphs.
@ -200,7 +197,7 @@ class TransportLayerClient(object):
request. request.
Returns: Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body (ie, the new event). will be the decoded JSON body (ie, the new event).
Fails with ``HTTPRequestException`` if we get an HTTP response Fails with ``HTTPRequestException`` if we get an HTTP response
@ -231,7 +228,7 @@ class TransportLayerClient(object):
ignore_backoff = True ignore_backoff = True
retry_on_dns_fail = True retry_on_dns_fail = True
content = yield self.client.get_json( content = await self.client.get_json(
destination=destination, destination=destination,
path=path, path=path,
args=params, args=params,
@ -242,34 +239,31 @@ class TransportLayerClient(object):
return content return content
@defer.inlineCallbacks
@log_function @log_function
def send_join_v1(self, destination, room_id, event_id, content): async def send_join_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_join/%s/%s", room_id, event_id) path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json( response = await self.client.put_json(
destination=destination, path=path, data=content destination=destination, path=path, data=content
) )
return response return response
@defer.inlineCallbacks
@log_function @log_function
def send_join_v2(self, destination, room_id, event_id, content): async def send_join_v2(self, destination, room_id, event_id, content):
path = _create_v2_path("/send_join/%s/%s", room_id, event_id) path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json( response = await self.client.put_json(
destination=destination, path=path, data=content destination=destination, path=path, data=content
) )
return response return response
@defer.inlineCallbacks
@log_function @log_function
def send_leave_v1(self, destination, room_id, event_id, content): async def send_leave_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_leave/%s/%s", room_id, event_id) path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json( response = await self.client.put_json(
destination=destination, destination=destination,
path=path, path=path,
data=content, data=content,
@ -282,12 +276,11 @@ class TransportLayerClient(object):
return response return response
@defer.inlineCallbacks
@log_function @log_function
def send_leave_v2(self, destination, room_id, event_id, content): async def send_leave_v2(self, destination, room_id, event_id, content):
path = _create_v2_path("/send_leave/%s/%s", room_id, event_id) path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json( response = await self.client.put_json(
destination=destination, destination=destination,
path=path, path=path,
data=content, data=content,
@ -300,31 +293,28 @@ class TransportLayerClient(object):
return response return response
@defer.inlineCallbacks
@log_function @log_function
def send_invite_v1(self, destination, room_id, event_id, content): async def send_invite_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/invite/%s/%s", room_id, event_id) path = _create_v1_path("/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json( response = await self.client.put_json(
destination=destination, path=path, data=content, ignore_backoff=True destination=destination, path=path, data=content, ignore_backoff=True
) )
return response return response
@defer.inlineCallbacks
@log_function @log_function
def send_invite_v2(self, destination, room_id, event_id, content): async def send_invite_v2(self, destination, room_id, event_id, content):
path = _create_v2_path("/invite/%s/%s", room_id, event_id) path = _create_v2_path("/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json( response = await self.client.put_json(
destination=destination, path=path, data=content, ignore_backoff=True destination=destination, path=path, data=content, ignore_backoff=True
) )
return response return response
@defer.inlineCallbacks
@log_function @log_function
def get_public_rooms( async def get_public_rooms(
self, self,
remote_server: str, remote_server: str,
limit: Optional[int] = None, limit: Optional[int] = None,
@ -355,7 +345,7 @@ class TransportLayerClient(object):
data["filter"] = search_filter data["filter"] = search_filter
try: try:
response = yield self.client.post_json( response = await self.client.post_json(
destination=remote_server, path=path, data=data, ignore_backoff=True destination=remote_server, path=path, data=data, ignore_backoff=True
) )
except HttpResponseException as e: except HttpResponseException as e:
@ -381,7 +371,7 @@ class TransportLayerClient(object):
args["since"] = [since_token] args["since"] = [since_token]
try: try:
response = yield self.client.get_json( response = await self.client.get_json(
destination=remote_server, path=path, args=args, ignore_backoff=True destination=remote_server, path=path, args=args, ignore_backoff=True
) )
except HttpResponseException as e: except HttpResponseException as e:
@ -396,29 +386,26 @@ class TransportLayerClient(object):
return response return response
@defer.inlineCallbacks
@log_function @log_function
def exchange_third_party_invite(self, destination, room_id, event_dict): async def exchange_third_party_invite(self, destination, room_id, event_dict):
path = _create_v1_path("/exchange_third_party_invite/%s", room_id) path = _create_v1_path("/exchange_third_party_invite/%s", room_id)
response = yield self.client.put_json( response = await self.client.put_json(
destination=destination, path=path, data=event_dict destination=destination, path=path, data=event_dict
) )
return response return response
@defer.inlineCallbacks
@log_function @log_function
def get_event_auth(self, destination, room_id, event_id): async def get_event_auth(self, destination, room_id, event_id):
path = _create_v1_path("/event_auth/%s/%s", room_id, event_id) path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
content = yield self.client.get_json(destination=destination, path=path) content = await self.client.get_json(destination=destination, path=path)
return content return content
@defer.inlineCallbacks
@log_function @log_function
def query_client_keys(self, destination, query_content, timeout): async def query_client_keys(self, destination, query_content, timeout):
"""Query the device keys for a list of user ids hosted on a remote """Query the device keys for a list of user ids hosted on a remote
server. server.
@ -453,14 +440,13 @@ class TransportLayerClient(object):
""" """
path = _create_v1_path("/user/keys/query") path = _create_v1_path("/user/keys/query")
content = yield self.client.post_json( content = await self.client.post_json(
destination=destination, path=path, data=query_content, timeout=timeout destination=destination, path=path, data=query_content, timeout=timeout
) )
return content return content
@defer.inlineCallbacks
@log_function @log_function
def query_user_devices(self, destination, user_id, timeout): async def query_user_devices(self, destination, user_id, timeout):
"""Query the devices for a user id hosted on a remote server. """Query the devices for a user id hosted on a remote server.
Response: Response:
@ -493,14 +479,13 @@ class TransportLayerClient(object):
""" """
path = _create_v1_path("/user/devices/%s", user_id) path = _create_v1_path("/user/devices/%s", user_id)
content = yield self.client.get_json( content = await self.client.get_json(
destination=destination, path=path, timeout=timeout destination=destination, path=path, timeout=timeout
) )
return content return content
@defer.inlineCallbacks
@log_function @log_function
def claim_client_keys(self, destination, query_content, timeout): async def claim_client_keys(self, destination, query_content, timeout):
"""Claim one-time keys for a list of devices hosted on a remote server. """Claim one-time keys for a list of devices hosted on a remote server.
Request: Request:
@ -532,14 +517,13 @@ class TransportLayerClient(object):
path = _create_v1_path("/user/keys/claim") path = _create_v1_path("/user/keys/claim")
content = yield self.client.post_json( content = await self.client.post_json(
destination=destination, path=path, data=query_content, timeout=timeout destination=destination, path=path, data=query_content, timeout=timeout
) )
return content return content
@defer.inlineCallbacks
@log_function @log_function
def get_missing_events( async def get_missing_events(
self, self,
destination, destination,
room_id, room_id,
@ -551,7 +535,7 @@ class TransportLayerClient(object):
): ):
path = _create_v1_path("/get_missing_events/%s", room_id) path = _create_v1_path("/get_missing_events/%s", room_id)
content = yield self.client.post_json( content = await self.client.post_json(
destination=destination, destination=destination,
path=path, path=path,
data={ data={

View file

@ -41,8 +41,6 @@ from typing import Tuple
from signedjson.sign import sign_json from signedjson.sign import sign_json
from twisted.internet import defer
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
@ -72,8 +70,9 @@ class GroupAttestationSigning(object):
self.server_name = hs.hostname self.server_name = hs.hostname
self.signing_key = hs.signing_key self.signing_key = hs.signing_key
@defer.inlineCallbacks async def verify_attestation(
def verify_attestation(self, attestation, group_id, user_id, server_name=None): self, attestation, group_id, user_id, server_name=None
):
"""Verifies that the given attestation matches the given parameters. """Verifies that the given attestation matches the given parameters.
An optional server_name can be supplied to explicitly set which server's An optional server_name can be supplied to explicitly set which server's
@ -102,7 +101,7 @@ class GroupAttestationSigning(object):
if valid_until_ms < now: if valid_until_ms < now:
raise SynapseError(400, "Attestation expired") raise SynapseError(400, "Attestation expired")
yield self.keyring.verify_json_for_server( await self.keyring.verify_json_for_server(
server_name, attestation, now, "Group attestation" server_name, attestation, now, "Group attestation"
) )
@ -142,8 +141,7 @@ class GroupAttestionRenewer(object):
self._start_renew_attestations, 30 * 60 * 1000 self._start_renew_attestations, 30 * 60 * 1000
) )
@defer.inlineCallbacks async def on_renew_attestation(self, group_id, user_id, content):
def on_renew_attestation(self, group_id, user_id, content):
"""When a remote updates an attestation """When a remote updates an attestation
""" """
attestation = content["attestation"] attestation = content["attestation"]
@ -151,11 +149,11 @@ class GroupAttestionRenewer(object):
if not self.is_mine_id(group_id) and not self.is_mine_id(user_id): if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
raise SynapseError(400, "Neither user not group are on this server") raise SynapseError(400, "Neither user not group are on this server")
yield self.attestations.verify_attestation( await self.attestations.verify_attestation(
attestation, user_id=user_id, group_id=group_id attestation, user_id=user_id, group_id=group_id
) )
yield self.store.update_remote_attestion(group_id, user_id, attestation) await self.store.update_remote_attestion(group_id, user_id, attestation)
return {} return {}
@ -172,8 +170,7 @@ class GroupAttestionRenewer(object):
now + UPDATE_ATTESTATION_TIME_MS now + UPDATE_ATTESTATION_TIME_MS
) )
@defer.inlineCallbacks async def _renew_attestation(group_user: Tuple[str, str]):
def _renew_attestation(group_user: Tuple[str, str]):
group_id, user_id = group_user group_id, user_id = group_user
try: try:
if not self.is_mine_id(group_id): if not self.is_mine_id(group_id):
@ -186,16 +183,16 @@ class GroupAttestionRenewer(object):
user_id, user_id,
group_id, group_id,
) )
yield self.store.remove_attestation_renewal(group_id, user_id) await self.store.remove_attestation_renewal(group_id, user_id)
return return
attestation = self.attestations.create_attestation(group_id, user_id) attestation = self.attestations.create_attestation(group_id, user_id)
yield self.transport_client.renew_group_attestation( await self.transport_client.renew_group_attestation(
destination, group_id, user_id, content={"attestation": attestation} destination, group_id, user_id, content={"attestation": attestation}
) )
yield self.store.update_attestation_renewal( await self.store.update_attestation_renewal(
group_id, user_id, attestation group_id, user_id, attestation
) )
except (RequestSendFailed, HttpResponseException) as e: except (RequestSendFailed, HttpResponseException) as e:

View file

@ -27,7 +27,6 @@ from synapse.metrics import (
event_processing_loop_room_count, event_processing_loop_room_count,
) )
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import log_failure
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -100,10 +99,11 @@ class ApplicationServicesHandler(object):
if not self.started_scheduler: if not self.started_scheduler:
def start_scheduler(): async def start_scheduler():
return self.scheduler.start().addErrback( try:
log_failure, "Application Services Failure" return self.scheduler.start()
) except Exception:
logger.error("Application Services Failure")
run_as_background_process("as_scheduler", start_scheduler) run_as_background_process("as_scheduler", start_scheduler)
self.started_scheduler = True self.started_scheduler = True

View file

@ -2470,7 +2470,7 @@ class FederationHandler(BaseHandler):
} }
current_state_ids = await context.get_current_state_ids() current_state_ids = await context.get_current_state_ids()
current_state_ids = dict(current_state_ids) current_state_ids = dict(current_state_ids) # type: ignore
current_state_ids.update(state_updates) current_state_ids.update(state_updates)

View file

@ -23,39 +23,32 @@ logger = logging.getLogger(__name__)
def _create_rerouter(func_name): def _create_rerouter(func_name):
"""Returns a function that looks at the group id and calls the function """Returns an async function that looks at the group id and calls the function
on federation or the local group server if the group is local on federation or the local group server if the group is local
""" """
def f(self, group_id, *args, **kwargs): async def f(self, group_id, *args, **kwargs):
if self.is_mine_id(group_id): if self.is_mine_id(group_id):
return getattr(self.groups_server_handler, func_name)( return await getattr(self.groups_server_handler, func_name)(
group_id, *args, **kwargs group_id, *args, **kwargs
) )
else: else:
destination = get_domain_from_id(group_id) destination = get_domain_from_id(group_id)
d = getattr(self.transport_client, func_name)(
destination, group_id, *args, **kwargs
)
# Capture errors returned by the remote homeserver and try:
# re-throw specific errors as SynapseErrors. This is so return await getattr(self.transport_client, func_name)(
# when the remote end responds with things like 403 Not destination, group_id, *args, **kwargs
# In Group, we can communicate that to the client instead )
# of a 500. except HttpResponseException as e:
def http_response_errback(failure): # Capture errors returned by the remote homeserver and
failure.trap(HttpResponseException) # re-throw specific errors as SynapseErrors. This is so
e = failure.value # when the remote end responds with things like 403 Not
# In Group, we can communicate that to the client instead
# of a 500.
raise e.to_synapse_error() raise e.to_synapse_error()
except RequestSendFailed:
def request_failed_errback(failure):
failure.trap(RequestSendFailed)
raise SynapseError(502, "Failed to contact group server") raise SynapseError(502, "Failed to contact group server")
d.addErrback(http_response_errback)
d.addErrback(request_failed_errback)
return d
return f return f

View file

@ -502,26 +502,39 @@ class RoomMemberHandler(object):
user_id=target.to_string(), room_id=room_id user_id=target.to_string(), room_id=room_id
) # type: Optional[RoomsForUser] ) # type: Optional[RoomsForUser]
if not invite: if not invite:
logger.info(
"%s sent a leave request to %s, but that is not an active room "
"on this server, and there is no pending invite",
target,
room_id,
)
raise SynapseError(404, "Not a known room") raise SynapseError(404, "Not a known room")
logger.info( logger.info(
"%s rejects invite to %s from %s", target, room_id, invite.sender "%s rejects invite to %s from %s", target, room_id, invite.sender
) )
if self.hs.is_mine_id(invite.sender): if not self.hs.is_mine_id(invite.sender):
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath.
#
# This is a bit of a hack, because the room might still be
# active on other servers.
pass
else:
# send the rejection to the inviter's HS (with fallback to # send the rejection to the inviter's HS (with fallback to
# local event) # local event)
return await self.remote_reject_invite( return await self.remote_reject_invite(
invite.event_id, txn_id, requester, content, invite.event_id, txn_id, requester, content,
) )
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath, which will also send the
# rejection out to any other servers we believe are still in the room.
# thanks to overzealous cleaning up of event_forward_extremities in
# `delete_old_current_state_events`, it's possible to end up with no
# forward extremities here. If that happens, let's just hang the
# rejection off the invite event.
#
# see: https://github.com/matrix-org/synapse/issues/7139
if len(latest_event_ids) == 0:
latest_event_ids = [invite.event_id]
return await self._local_membership_update( return await self._local_membership_update(
requester=requester, requester=requester,
target=target, target=target,
@ -985,7 +998,11 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if len(remote_room_hosts) == 0: if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers") raise SynapseError(404, "No known servers")
if self.hs.config.limit_remote_rooms.enabled: check_complexity = self.hs.config.limit_remote_rooms.enabled
if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join:
check_complexity = not await self.hs.auth.is_server_admin(user)
if check_complexity:
# Fetch the room complexity # Fetch the room complexity
too_complex = await self._is_remote_room_too_complex( too_complex = await self._is_remote_room_too_complex(
room_id, remote_room_hosts room_id, remote_room_hosts
@ -1008,7 +1025,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# Check the room we just joined wasn't too large, if we didn't fetch the # Check the room we just joined wasn't too large, if we didn't fetch the
# complexity of it before. # complexity of it before.
if self.hs.config.limit_remote_rooms.enabled: if check_complexity:
if too_complex is False: if too_complex is False:
# We checked, and we're under the limit. # We checked, and we're under the limit.
return event_id, stream_id return event_id, stream_id

View file

@ -96,6 +96,9 @@ class SamlHandler:
relay_state=client_redirect_url relay_state=client_redirect_url
) )
# Since SAML sessions timeout it is useful to log when they were created.
logger.info("Initiating a new SAML session: %s" % (reqid,))
now = self._clock.time_msec() now = self._clock.time_msec()
self._outstanding_requests_dict[reqid] = Saml2SessionData( self._outstanding_requests_dict[reqid] = Saml2SessionData(
creation_time=now, ui_auth_session_id=ui_auth_session_id, creation_time=now, ui_auth_session_id=ui_auth_session_id,

View file

@ -103,6 +103,7 @@ class JoinedSyncResult:
account_data = attr.ib(type=List[JsonDict]) account_data = attr.ib(type=List[JsonDict])
unread_notifications = attr.ib(type=JsonDict) unread_notifications = attr.ib(type=JsonDict)
summary = attr.ib(type=Optional[JsonDict]) summary = attr.ib(type=Optional[JsonDict])
unread_count = attr.ib(type=int)
def __nonzero__(self) -> bool: def __nonzero__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used """Make the result appear empty if there are no updates. This is used
@ -1886,6 +1887,10 @@ class SyncHandler(object):
if room_builder.rtype == "joined": if room_builder.rtype == "joined":
unread_notifications = {} # type: Dict[str, str] unread_notifications = {} # type: Dict[str, str]
unread_count = await self.store.get_unread_message_count_for_user(
room_id, sync_config.user.to_string(),
)
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
@ -1894,6 +1899,7 @@ class SyncHandler(object):
account_data=account_data_events, account_data=account_data_events,
unread_notifications=unread_notifications, unread_notifications=unread_notifications,
summary=summary, summary=summary,
unread_count=unread_count,
) )
if room_sync or always_include: if room_sync or always_include:

View file

@ -395,7 +395,9 @@ class SimpleHttpClient(object):
if 200 <= response.code < 300: if 200 <= response.code < 300:
return json.loads(body.decode("utf-8")) return json.loads(body.decode("utf-8"))
else: else:
raise HttpResponseException(response.code, response.phrase, body) raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
)
@defer.inlineCallbacks @defer.inlineCallbacks
def post_json_get_json(self, uri, post_json, headers=None): def post_json_get_json(self, uri, post_json, headers=None):
@ -436,7 +438,9 @@ class SimpleHttpClient(object):
if 200 <= response.code < 300: if 200 <= response.code < 300:
return json.loads(body.decode("utf-8")) return json.loads(body.decode("utf-8"))
else: else:
raise HttpResponseException(response.code, response.phrase, body) raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, uri, args={}, headers=None): def get_json(self, uri, args={}, headers=None):
@ -509,7 +513,9 @@ class SimpleHttpClient(object):
if 200 <= response.code < 300: if 200 <= response.code < 300:
return json.loads(body.decode("utf-8")) return json.loads(body.decode("utf-8"))
else: else:
raise HttpResponseException(response.code, response.phrase, body) raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_raw(self, uri, args={}, headers=None): def get_raw(self, uri, args={}, headers=None):
@ -544,7 +550,9 @@ class SimpleHttpClient(object):
if 200 <= response.code < 300: if 200 <= response.code < 300:
return body return body
else: else:
raise HttpResponseException(response.code, response.phrase, body) raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out. # The two should be factored out.

View file

@ -121,8 +121,7 @@ class MatrixFederationRequest(object):
return self.json return self.json
@defer.inlineCallbacks async def _handle_json_response(reactor, timeout_sec, request, response):
def _handle_json_response(reactor, timeout_sec, request, response):
""" """
Reads the JSON body of a response, with a timeout Reads the JSON body of a response, with a timeout
@ -141,7 +140,7 @@ def _handle_json_response(reactor, timeout_sec, request, response):
d = treq.json_content(response) d = treq.json_content(response)
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
body = yield make_deferred_yieldable(d) body = await make_deferred_yieldable(d)
except TimeoutError as e: except TimeoutError as e:
logger.warning( logger.warning(
"{%s} [%s] Timed out reading response", request.txn_id, request.destination, "{%s} [%s] Timed out reading response", request.txn_id, request.destination,
@ -224,8 +223,7 @@ class MatrixFederationHttpClient(object):
self._cooperator = Cooperator(scheduler=schedule) self._cooperator = Cooperator(scheduler=schedule)
@defer.inlineCallbacks async def _send_request_with_optional_trailing_slash(
def _send_request_with_optional_trailing_slash(
self, request, try_trailing_slash_on_400=False, **send_request_args self, request, try_trailing_slash_on_400=False, **send_request_args
): ):
"""Wrapper for _send_request which can optionally retry the request """Wrapper for _send_request which can optionally retry the request
@ -246,10 +244,10 @@ class MatrixFederationHttpClient(object):
(except 429). (except 429).
Returns: Returns:
Deferred[Dict]: Parsed JSON response body. Dict: Parsed JSON response body.
""" """
try: try:
response = yield self._send_request(request, **send_request_args) response = await self._send_request(request, **send_request_args)
except HttpResponseException as e: except HttpResponseException as e:
# Received an HTTP error > 300. Check if it meets the requirements # Received an HTTP error > 300. Check if it meets the requirements
# to retry with a trailing slash # to retry with a trailing slash
@ -265,12 +263,11 @@ class MatrixFederationHttpClient(object):
logger.info("Retrying request with trailing slash") logger.info("Retrying request with trailing slash")
request.path += "/" request.path += "/"
response = yield self._send_request(request, **send_request_args) response = await self._send_request(request, **send_request_args)
return response return response
@defer.inlineCallbacks async def _send_request(
def _send_request(
self, self,
request, request,
retry_on_dns_fail=True, retry_on_dns_fail=True,
@ -311,7 +308,7 @@ class MatrixFederationHttpClient(object):
backoff_on_404 (bool): Back off if we get a 404 backoff_on_404 (bool): Back off if we get a 404
Returns: Returns:
Deferred[twisted.web.client.Response]: resolves with the HTTP twisted.web.client.Response: resolves with the HTTP
response object on success. response object on success.
Raises: Raises:
@ -335,7 +332,7 @@ class MatrixFederationHttpClient(object):
): ):
raise FederationDeniedError(request.destination) raise FederationDeniedError(request.destination)
limiter = yield synapse.util.retryutils.get_retry_limiter( limiter = await synapse.util.retryutils.get_retry_limiter(
request.destination, request.destination,
self.clock, self.clock,
self._store, self._store,
@ -433,7 +430,7 @@ class MatrixFederationHttpClient(object):
reactor=self.reactor, reactor=self.reactor,
) )
response = yield request_deferred response = await request_deferred
except TimeoutError as e: except TimeoutError as e:
raise RequestSendFailed(e, can_retry=True) from e raise RequestSendFailed(e, can_retry=True) from e
except DNSLookupError as e: except DNSLookupError as e:
@ -447,6 +444,7 @@ class MatrixFederationHttpClient(object):
).inc() ).inc()
set_tag(tags.HTTP_STATUS_CODE, response.code) set_tag(tags.HTTP_STATUS_CODE, response.code)
response_phrase = response.phrase.decode("ascii", errors="replace")
if 200 <= response.code < 300: if 200 <= response.code < 300:
logger.debug( logger.debug(
@ -454,7 +452,7 @@ class MatrixFederationHttpClient(object):
request.txn_id, request.txn_id,
request.destination, request.destination,
response.code, response.code,
response.phrase.decode("ascii", errors="replace"), response_phrase,
) )
pass pass
else: else:
@ -463,7 +461,7 @@ class MatrixFederationHttpClient(object):
request.txn_id, request.txn_id,
request.destination, request.destination,
response.code, response.code,
response.phrase.decode("ascii", errors="replace"), response_phrase,
) )
# :'( # :'(
# Update transactions table? # Update transactions table?
@ -473,7 +471,7 @@ class MatrixFederationHttpClient(object):
) )
try: try:
body = yield make_deferred_yieldable(d) body = await make_deferred_yieldable(d)
except Exception as e: except Exception as e:
# Eh, we're already going to raise an exception so lets # Eh, we're already going to raise an exception so lets
# ignore if this fails. # ignore if this fails.
@ -487,7 +485,7 @@ class MatrixFederationHttpClient(object):
) )
body = None body = None
e = HttpResponseException(response.code, response.phrase, body) e = HttpResponseException(response.code, response_phrase, body)
# Retry if the error is a 429 (Too Many Requests), # Retry if the error is a 429 (Too Many Requests),
# otherwise just raise a standard HttpResponseException # otherwise just raise a standard HttpResponseException
@ -527,7 +525,7 @@ class MatrixFederationHttpClient(object):
delay, delay,
) )
yield self.clock.sleep(delay) await self.clock.sleep(delay)
retries_left -= 1 retries_left -= 1
else: else:
raise raise
@ -590,8 +588,7 @@ class MatrixFederationHttpClient(object):
) )
return auth_headers return auth_headers
@defer.inlineCallbacks async def put_json(
def put_json(
self, self,
destination, destination,
path, path,
@ -635,7 +632,7 @@ class MatrixFederationHttpClient(object):
enabled. enabled.
Returns: Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The dict|list: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body. result will be the decoded JSON body.
Raises: Raises:
@ -657,7 +654,7 @@ class MatrixFederationHttpClient(object):
json=data, json=data,
) )
response = yield self._send_request_with_optional_trailing_slash( response = await self._send_request_with_optional_trailing_slash(
request, request,
try_trailing_slash_on_400, try_trailing_slash_on_400,
backoff_on_404=backoff_on_404, backoff_on_404=backoff_on_404,
@ -666,14 +663,13 @@ class MatrixFederationHttpClient(object):
timeout=timeout, timeout=timeout,
) )
body = yield _handle_json_response( body = await _handle_json_response(
self.reactor, self.default_timeout, request, response self.reactor, self.default_timeout, request, response
) )
return body return body
@defer.inlineCallbacks async def post_json(
def post_json(
self, self,
destination, destination,
path, path,
@ -706,7 +702,7 @@ class MatrixFederationHttpClient(object):
args (dict): query params args (dict): query params
Returns: Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The dict|list: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body. result will be the decoded JSON body.
Raises: Raises:
@ -724,7 +720,7 @@ class MatrixFederationHttpClient(object):
method="POST", destination=destination, path=path, query=args, json=data method="POST", destination=destination, path=path, query=args, json=data
) )
response = yield self._send_request( response = await self._send_request(
request, request,
long_retries=long_retries, long_retries=long_retries,
timeout=timeout, timeout=timeout,
@ -736,13 +732,12 @@ class MatrixFederationHttpClient(object):
else: else:
_sec_timeout = self.default_timeout _sec_timeout = self.default_timeout
body = yield _handle_json_response( body = await _handle_json_response(
self.reactor, _sec_timeout, request, response self.reactor, _sec_timeout, request, response
) )
return body return body
@defer.inlineCallbacks async def get_json(
def get_json(
self, self,
destination, destination,
path, path,
@ -774,7 +769,7 @@ class MatrixFederationHttpClient(object):
response we should try appending a trailing slash to the end of response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3. the request. Workaround for #3622 in Synapse <= v0.99.3.
Returns: Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The dict|list: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body. result will be the decoded JSON body.
Raises: Raises:
@ -791,7 +786,7 @@ class MatrixFederationHttpClient(object):
method="GET", destination=destination, path=path, query=args method="GET", destination=destination, path=path, query=args
) )
response = yield self._send_request_with_optional_trailing_slash( response = await self._send_request_with_optional_trailing_slash(
request, request,
try_trailing_slash_on_400, try_trailing_slash_on_400,
backoff_on_404=False, backoff_on_404=False,
@ -800,14 +795,13 @@ class MatrixFederationHttpClient(object):
timeout=timeout, timeout=timeout,
) )
body = yield _handle_json_response( body = await _handle_json_response(
self.reactor, self.default_timeout, request, response self.reactor, self.default_timeout, request, response
) )
return body return body
@defer.inlineCallbacks async def delete_json(
def delete_json(
self, self,
destination, destination,
path, path,
@ -835,7 +829,7 @@ class MatrixFederationHttpClient(object):
args (dict): query params args (dict): query params
Returns: Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The dict|list: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body. result will be the decoded JSON body.
Raises: Raises:
@ -852,20 +846,19 @@ class MatrixFederationHttpClient(object):
method="DELETE", destination=destination, path=path, query=args method="DELETE", destination=destination, path=path, query=args
) )
response = yield self._send_request( response = await self._send_request(
request, request,
long_retries=long_retries, long_retries=long_retries,
timeout=timeout, timeout=timeout,
ignore_backoff=ignore_backoff, ignore_backoff=ignore_backoff,
) )
body = yield _handle_json_response( body = await _handle_json_response(
self.reactor, self.default_timeout, request, response self.reactor, self.default_timeout, request, response
) )
return body return body
@defer.inlineCallbacks async def get_file(
def get_file(
self, self,
destination, destination,
path, path,
@ -885,7 +878,7 @@ class MatrixFederationHttpClient(object):
and try the request anyway. and try the request anyway.
Returns: Returns:
Deferred[tuple[int, dict]]: Resolves with an (int,dict) tuple of tuple[int, dict]: Resolves with an (int,dict) tuple of
the file length and a dict of the response headers. the file length and a dict of the response headers.
Raises: Raises:
@ -902,7 +895,7 @@ class MatrixFederationHttpClient(object):
method="GET", destination=destination, path=path, query=args method="GET", destination=destination, path=path, query=args
) )
response = yield self._send_request( response = await self._send_request(
request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff
) )
@ -911,7 +904,7 @@ class MatrixFederationHttpClient(object):
try: try:
d = _readBodyToFile(response, output_stream, max_size) d = _readBodyToFile(response, output_stream, max_size)
d.addTimeout(self.default_timeout, self.reactor) d.addTimeout(self.default_timeout, self.reactor)
length = yield make_deferred_yieldable(d) length = await make_deferred_yieldable(d)
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"{%s} [%s] Error reading response: %s", "{%s} [%s] Error reading response: %s",

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from .bulk_push_rule_evaluator import BulkPushRuleEvaluator from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
@ -37,7 +35,6 @@ class ActionGenerator(object):
# event stream, so we just run the rules for a client with no profile # event stream, so we just run the rules for a client with no profile
# tag (ie. we just need all the users). # tag (ie. we just need all the users).
@defer.inlineCallbacks async def handle_push_actions_for_event(self, event, context):
def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "action_for_event_by_user"): with Measure(self.clock, "action_for_event_by_user"):
yield self.bulk_evaluator.action_for_event_by_user(event, context) await self.bulk_evaluator.action_for_event_by_user(event, context)

View file

@ -19,8 +19,6 @@ from collections import namedtuple
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.event_auth import get_user_power_level from synapse.event_auth import get_user_power_level
from synapse.state import POWER_KEY from synapse.state import POWER_KEY
@ -70,8 +68,7 @@ class BulkPushRuleEvaluator(object):
resizable=False, resizable=False,
) )
@defer.inlineCallbacks async def _get_rules_for_event(self, event, context):
def _get_rules_for_event(self, event, context):
"""This gets the rules for all users in the room at the time of the event, """This gets the rules for all users in the room at the time of the event,
as well as the push rules for the invitee if the event is an invite. as well as the push rules for the invitee if the event is an invite.
@ -79,19 +76,19 @@ class BulkPushRuleEvaluator(object):
dict of user_id -> push_rules dict of user_id -> push_rules
""" """
room_id = event.room_id room_id = event.room_id
rules_for_room = yield self._get_rules_for_room(room_id) rules_for_room = await self._get_rules_for_room(room_id)
rules_by_user = yield rules_for_room.get_rules(event, context) rules_by_user = await rules_for_room.get_rules(event, context)
# if this event is an invite event, we may need to run rules for the user # if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited # who's been invited, otherwise they won't get told they've been invited
if event.type == "m.room.member" and event.content["membership"] == "invite": if event.type == "m.room.member" and event.content["membership"] == "invite":
invited = event.state_key invited = event.state_key
if invited and self.hs.is_mine_id(invited): if invited and self.hs.is_mine_id(invited):
has_pusher = yield self.store.user_has_pusher(invited) has_pusher = await self.store.user_has_pusher(invited)
if has_pusher: if has_pusher:
rules_by_user = dict(rules_by_user) rules_by_user = dict(rules_by_user)
rules_by_user[invited] = yield self.store.get_push_rules_for_user( rules_by_user[invited] = await self.store.get_push_rules_for_user(
invited invited
) )
@ -114,20 +111,19 @@ class BulkPushRuleEvaluator(object):
self.room_push_rule_cache_metrics, self.room_push_rule_cache_metrics,
) )
@defer.inlineCallbacks async def _get_power_levels_and_sender_level(self, event, context):
def _get_power_levels_and_sender_level(self, event, context): prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = yield context.get_prev_state_ids()
pl_event_id = prev_state_ids.get(POWER_KEY) pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id: if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and # fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case # not having a power level event is an extreme edge case
pl_event = yield self.store.get_event(pl_event_id) pl_event = await self.store.get_event(pl_event_id)
auth_events = {POWER_KEY: pl_event} auth_events = {POWER_KEY: pl_event}
else: else:
auth_events_ids = yield self.auth.compute_auth_events( auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=False event, prev_state_ids, for_verification=False
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()} auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
sender_level = get_user_power_level(event.sender, auth_events) sender_level = get_user_power_level(event.sender, auth_events)
@ -136,23 +132,19 @@ class BulkPushRuleEvaluator(object):
return pl_event.content if pl_event else {}, sender_level return pl_event.content if pl_event else {}, sender_level
@defer.inlineCallbacks async def action_for_event_by_user(self, event, context) -> None:
def action_for_event_by_user(self, event, context):
"""Given an event and context, evaluate the push rules and insert the """Given an event and context, evaluate the push rules and insert the
results into the event_push_actions_staging table. results into the event_push_actions_staging table.
Returns:
Deferred
""" """
rules_by_user = yield self._get_rules_for_event(event, context) rules_by_user = await self._get_rules_for_event(event, context)
actions_by_user = {} actions_by_user = {}
room_members = yield self.store.get_joined_users_from_context(event, context) room_members = await self.store.get_joined_users_from_context(event, context)
( (
power_levels, power_levels,
sender_power_level, sender_power_level,
) = yield self._get_power_levels_and_sender_level(event, context) ) = await self._get_power_levels_and_sender_level(event, context)
evaluator = PushRuleEvaluatorForEvent( evaluator = PushRuleEvaluatorForEvent(
event, len(room_members), sender_power_level, power_levels event, len(room_members), sender_power_level, power_levels
@ -165,7 +157,7 @@ class BulkPushRuleEvaluator(object):
continue continue
if not event.is_state(): if not event.is_state():
is_ignored = yield self.store.is_ignored_by(event.sender, uid) is_ignored = await self.store.is_ignored_by(event.sender, uid)
if is_ignored: if is_ignored:
continue continue
@ -197,7 +189,7 @@ class BulkPushRuleEvaluator(object):
# Mark in the DB staging area the push actions for users who should be # Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist # notified for this event. (This will then get handled when we persist
# the event) # the event)
yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user) await self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
def _condition_checker(evaluator, conditions, uid, display_name, cache): def _condition_checker(evaluator, conditions, uid, display_name, cache):
@ -274,8 +266,7 @@ class RulesForRoom(object):
# to self around in the callback. # to self around in the callback.
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id) self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
@defer.inlineCallbacks async def get_rules(self, event, context):
def get_rules(self, event, context):
"""Given an event context return the rules for all users who are """Given an event context return the rules for all users who are
currently in the room. currently in the room.
""" """
@ -286,7 +277,7 @@ class RulesForRoom(object):
self.room_push_rule_cache_metrics.inc_hits() self.room_push_rule_cache_metrics.inc_hits()
return self.rules_by_user return self.rules_by_user
with (yield self.linearizer.queue(())): with (await self.linearizer.queue(())):
if state_group and self.state_group == state_group: if state_group and self.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id) logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits() self.room_push_rule_cache_metrics.inc_hits()
@ -304,9 +295,7 @@ class RulesForRoom(object):
push_rules_delta_state_cache_metric.inc_hits() push_rules_delta_state_cache_metric.inc_hits()
else: else:
current_state_ids = yield defer.ensureDeferred( current_state_ids = await context.get_current_state_ids()
context.get_current_state_ids()
)
push_rules_delta_state_cache_metric.inc_misses() push_rules_delta_state_cache_metric.inc_misses()
push_rules_state_size_counter.inc(len(current_state_ids)) push_rules_state_size_counter.inc(len(current_state_ids))
@ -353,7 +342,7 @@ class RulesForRoom(object):
# If we have some memebr events we haven't seen, look them up # If we have some memebr events we haven't seen, look them up
# and fetch push rules for them if appropriate. # and fetch push rules for them if appropriate.
logger.debug("Found new member events %r", missing_member_event_ids) logger.debug("Found new member events %r", missing_member_event_ids)
yield self._update_rules_with_member_event_ids( await self._update_rules_with_member_event_ids(
ret_rules_by_user, missing_member_event_ids, state_group, event ret_rules_by_user, missing_member_event_ids, state_group, event
) )
else: else:
@ -371,8 +360,7 @@ class RulesForRoom(object):
) )
return ret_rules_by_user return ret_rules_by_user
@defer.inlineCallbacks async def _update_rules_with_member_event_ids(
def _update_rules_with_member_event_ids(
self, ret_rules_by_user, member_event_ids, state_group, event self, ret_rules_by_user, member_event_ids, state_group, event
): ):
"""Update the partially filled rules_by_user dict by fetching rules for """Update the partially filled rules_by_user dict by fetching rules for
@ -388,7 +376,7 @@ class RulesForRoom(object):
""" """
sequence = self.sequence sequence = self.sequence
rows = yield self.store.get_membership_from_event_ids(member_event_ids.values()) rows = await self.store.get_membership_from_event_ids(member_event_ids.values())
members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
@ -410,7 +398,7 @@ class RulesForRoom(object):
logger.debug("Joined: %r", interested_in_user_ids) logger.debug("Joined: %r", interested_in_user_ids)
if_users_with_pushers = yield self.store.get_if_users_have_pushers( if_users_with_pushers = await self.store.get_if_users_have_pushers(
interested_in_user_ids, on_invalidate=self.invalidate_all_cb interested_in_user_ids, on_invalidate=self.invalidate_all_cb
) )
@ -420,7 +408,7 @@ class RulesForRoom(object):
logger.debug("With pushers: %r", user_ids) logger.debug("With pushers: %r", user_ids)
users_with_receipts = yield self.store.get_users_with_read_receipts_in_room( users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
self.room_id, on_invalidate=self.invalidate_all_cb self.room_id, on_invalidate=self.invalidate_all_cb
) )
@ -431,7 +419,7 @@ class RulesForRoom(object):
if uid in interested_in_user_ids: if uid in interested_in_user_ids:
user_ids.add(uid) user_ids.add(uid)
rules_by_user = yield self.store.bulk_get_push_rules( rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb user_ids, on_invalidate=self.invalidate_all_cb
) )

View file

@ -17,7 +17,6 @@ import logging
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
@ -128,12 +127,11 @@ class HttpPusher(object):
# but currently that's the only type of receipt anyway... # but currently that's the only type of receipt anyway...
run_as_background_process("http_pusher.on_new_receipts", self._update_badge) run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
@defer.inlineCallbacks async def _update_badge(self):
def _update_badge(self):
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. perhaps we can remove it. # to be largely redundant. perhaps we can remove it.
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
yield self._send_badge(badge) await self._send_badge(badge)
def on_timer(self): def on_timer(self):
self._start_processing() self._start_processing()
@ -152,8 +150,7 @@ class HttpPusher(object):
run_as_background_process("httppush.process", self._process) run_as_background_process("httppush.process", self._process)
@defer.inlineCallbacks async def _process(self):
def _process(self):
# we should never get here if we are already processing # we should never get here if we are already processing
assert not self._is_processing assert not self._is_processing
@ -164,7 +161,7 @@ class HttpPusher(object):
while True: while True:
starting_max_ordering = self.max_stream_ordering starting_max_ordering = self.max_stream_ordering
try: try:
yield self._unsafe_process() await self._unsafe_process()
except Exception: except Exception:
logger.exception("Exception processing notifs") logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering: if self.max_stream_ordering == starting_max_ordering:
@ -172,8 +169,7 @@ class HttpPusher(object):
finally: finally:
self._is_processing = False self._is_processing = False
@defer.inlineCallbacks async def _unsafe_process(self):
def _unsafe_process(self):
""" """
Looks for unset notifications and dispatch them, in order Looks for unset notifications and dispatch them, in order
Never call this directly: use _process which will only allow this to Never call this directly: use _process which will only allow this to
@ -181,7 +177,7 @@ class HttpPusher(object):
""" """
fn = self.store.get_unread_push_actions_for_user_in_range_for_http fn = self.store.get_unread_push_actions_for_user_in_range_for_http
unprocessed = yield fn( unprocessed = await fn(
self.user_id, self.last_stream_ordering, self.max_stream_ordering self.user_id, self.last_stream_ordering, self.max_stream_ordering
) )
@ -203,13 +199,13 @@ class HttpPusher(object):
"app_display_name": self.app_display_name, "app_display_name": self.app_display_name,
}, },
): ):
processed = yield self._process_one(push_action) processed = await self._process_one(push_action)
if processed: if processed:
http_push_processed_counter.inc() http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"] self.last_stream_ordering = push_action["stream_ordering"]
pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_id, self.user_id,
@ -224,14 +220,14 @@ class HttpPusher(object):
if self.failing_since: if self.failing_since:
self.failing_since = None self.failing_since = None
yield self.store.update_pusher_failing_since( await self.store.update_pusher_failing_since(
self.app_id, self.pushkey, self.user_id, self.failing_since self.app_id, self.pushkey, self.user_id, self.failing_since
) )
else: else:
http_push_failed_counter.inc() http_push_failed_counter.inc()
if not self.failing_since: if not self.failing_since:
self.failing_since = self.clock.time_msec() self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since( await self.store.update_pusher_failing_since(
self.app_id, self.pushkey, self.user_id, self.failing_since self.app_id, self.pushkey, self.user_id, self.failing_since
) )
@ -250,7 +246,7 @@ class HttpPusher(object):
) )
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"] self.last_stream_ordering = push_action["stream_ordering"]
pusher_still_exists = yield self.store.update_pusher_last_stream_ordering( pusher_still_exists = await self.store.update_pusher_last_stream_ordering(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_id, self.user_id,
@ -263,7 +259,7 @@ class HttpPusher(object):
return return
self.failing_since = None self.failing_since = None
yield self.store.update_pusher_failing_since( await self.store.update_pusher_failing_since(
self.app_id, self.pushkey, self.user_id, self.failing_since self.app_id, self.pushkey, self.user_id, self.failing_since
) )
else: else:
@ -276,18 +272,17 @@ class HttpPusher(object):
) )
break break
@defer.inlineCallbacks async def _process_one(self, push_action):
def _process_one(self, push_action):
if "notify" not in push_action["actions"]: if "notify" not in push_action["actions"]:
return True return True
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
event = yield self.store.get_event(push_action["event_id"], allow_none=True) event = await self.store.get_event(push_action["event_id"], allow_none=True)
if event is None: if event is None:
return True # It's been redacted return True # It's been redacted
rejected = yield self.dispatch_push(event, tweaks, badge) rejected = await self.dispatch_push(event, tweaks, badge)
if rejected is False: if rejected is False:
return False return False
@ -301,11 +296,10 @@ class HttpPusher(object):
) )
else: else:
logger.info("Pushkey %s was rejected: removing", pk) logger.info("Pushkey %s was rejected: removing", pk)
yield self.hs.remove_pusher(self.app_id, pk, self.user_id) await self.hs.remove_pusher(self.app_id, pk, self.user_id)
return True return True
@defer.inlineCallbacks async def _build_notification_dict(self, event, tweaks, badge):
def _build_notification_dict(self, event, tweaks, badge):
priority = "low" priority = "low"
if ( if (
event.type == EventTypes.Encrypted event.type == EventTypes.Encrypted
@ -335,7 +329,7 @@ class HttpPusher(object):
} }
return d return d
ctx = yield push_tools.get_context_for_event( ctx = await push_tools.get_context_for_event(
self.storage, self.state_handler, event, self.user_id self.storage, self.state_handler, event, self.user_id
) )
@ -377,13 +371,12 @@ class HttpPusher(object):
return d return d
@defer.inlineCallbacks async def dispatch_push(self, event, tweaks, badge):
def dispatch_push(self, event, tweaks, badge): notification_dict = await self._build_notification_dict(event, tweaks, badge)
notification_dict = yield self._build_notification_dict(event, tweaks, badge)
if not notification_dict: if not notification_dict:
return [] return []
try: try:
resp = yield self.http_client.post_json_get_json( resp = await self.http_client.post_json_get_json(
self.url, notification_dict self.url, notification_dict
) )
except Exception as e: except Exception as e:
@ -400,8 +393,7 @@ class HttpPusher(object):
rejected = resp["rejected"] rejected = resp["rejected"]
return rejected return rejected
@defer.inlineCallbacks async def _send_badge(self, badge):
def _send_badge(self, badge):
""" """
Args: Args:
badge (int): number of unread messages badge (int): number of unread messages
@ -424,7 +416,7 @@ class HttpPusher(object):
} }
} }
try: try:
yield self.http_client.post_json_get_json(self.url, d) await self.http_client.post_json_get_json(self.url, d)
http_badges_processed_counter.inc() http_badges_processed_counter.inc()
except Exception as e: except Exception as e:
logger.warning( logger.warning(

View file

@ -16,8 +16,6 @@
import logging import logging
import re import re
from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,8 +27,7 @@ ALIAS_RE = re.compile(r"^#.*:.+$")
ALL_ALONE = "Empty Room" ALL_ALONE = "Empty Room"
@defer.inlineCallbacks async def calculate_room_name(
def calculate_room_name(
store, store,
room_state_ids, room_state_ids,
user_id, user_id,
@ -53,7 +50,7 @@ def calculate_room_name(
""" """
# does it have a name? # does it have a name?
if (EventTypes.Name, "") in room_state_ids: if (EventTypes.Name, "") in room_state_ids:
m_room_name = yield store.get_event( m_room_name = await store.get_event(
room_state_ids[(EventTypes.Name, "")], allow_none=True room_state_ids[(EventTypes.Name, "")], allow_none=True
) )
if m_room_name and m_room_name.content and m_room_name.content["name"]: if m_room_name and m_room_name.content and m_room_name.content["name"]:
@ -61,7 +58,7 @@ def calculate_room_name(
# does it have a canonical alias? # does it have a canonical alias?
if (EventTypes.CanonicalAlias, "") in room_state_ids: if (EventTypes.CanonicalAlias, "") in room_state_ids:
canon_alias = yield store.get_event( canon_alias = await store.get_event(
room_state_ids[(EventTypes.CanonicalAlias, "")], allow_none=True room_state_ids[(EventTypes.CanonicalAlias, "")], allow_none=True
) )
if ( if (
@ -81,7 +78,7 @@ def calculate_room_name(
my_member_event = None my_member_event = None
if (EventTypes.Member, user_id) in room_state_ids: if (EventTypes.Member, user_id) in room_state_ids:
my_member_event = yield store.get_event( my_member_event = await store.get_event(
room_state_ids[(EventTypes.Member, user_id)], allow_none=True room_state_ids[(EventTypes.Member, user_id)], allow_none=True
) )
@ -90,7 +87,7 @@ def calculate_room_name(
and my_member_event.content["membership"] == "invite" and my_member_event.content["membership"] == "invite"
): ):
if (EventTypes.Member, my_member_event.sender) in room_state_ids: if (EventTypes.Member, my_member_event.sender) in room_state_ids:
inviter_member_event = yield store.get_event( inviter_member_event = await store.get_event(
room_state_ids[(EventTypes.Member, my_member_event.sender)], room_state_ids[(EventTypes.Member, my_member_event.sender)],
allow_none=True, allow_none=True,
) )
@ -107,7 +104,7 @@ def calculate_room_name(
# we're going to have to generate a name based on who's in the room, # we're going to have to generate a name based on who's in the room,
# so find out who is in the room that isn't the user. # so find out who is in the room that isn't the user.
if EventTypes.Member in room_state_bytype_ids: if EventTypes.Member in room_state_bytype_ids:
member_events = yield store.get_events( member_events = await store.get_events(
list(room_state_bytype_ids[EventTypes.Member].values()) list(room_state_bytype_ids[EventTypes.Member].values())
) )
all_members = [ all_members = [

View file

@ -13,53 +13,40 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage from synapse.storage import Storage
@defer.inlineCallbacks async def get_badge_count(store, user_id):
def get_badge_count(store, user_id): invites = await store.get_invited_rooms_for_local_user(user_id)
invites = yield store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id)
joins = yield store.get_rooms_for_user(user_id)
my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read")
badge = len(invites) badge = len(invites)
for room_id in joins: for room_id in joins:
if room_id in my_receipts_by_room: unread_count = await store.get_unread_message_count_for_user(room_id, user_id)
last_unread_event_id = my_receipts_by_room[room_id] # return one badge count per conversation, as count per
# message is so noisy as to be almost useless
notifs = yield ( badge += 1 if unread_count else 0
store.get_unread_event_push_actions_by_room_for_user(
room_id, user_id, last_unread_event_id
)
)
# return one badge count per conversation, as count per
# message is so noisy as to be almost useless
badge += 1 if notifs["notify_count"] else 0
return badge return badge
@defer.inlineCallbacks async def get_context_for_event(storage: Storage, state_handler, ev, user_id):
def get_context_for_event(storage: Storage, state_handler, ev, user_id):
ctx = {} ctx = {}
room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id) room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id)
# we no longer bother setting room_alias, and make room_name the # we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or # human-readable name instead, be that m.room.name, an alias or
# a list of people in the room # a list of people in the room
name = yield calculate_room_name( name = await calculate_room_name(
storage.main, room_state_ids, user_id, fallback_to_single_member=False storage.main, room_state_ids, user_id, fallback_to_single_member=False
) )
if name: if name:
ctx["name"] = name ctx["name"] = name
sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
sender_state_event = yield storage.main.get_event(sender_state_event_id) sender_state_event = await storage.main.get_event(sender_state_event_id)
ctx["sender_display_name"] = name_from_member_event(sender_state_event) ctx["sender_display_name"] = name_from_member_event(sender_state_event)
return ctx return ctx

View file

@ -19,8 +19,6 @@ from typing import TYPE_CHECKING, Dict, Union
from prometheus_client import Gauge from prometheus_client import Gauge
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.push.emailpusher import EmailPusher from synapse.push.emailpusher import EmailPusher
@ -52,7 +50,7 @@ class PusherPool:
Note that it is expected that each pusher will have its own 'processing' loop which Note that it is expected that each pusher will have its own 'processing' loop which
will send out the notifications in the background, rather than blocking until the will send out the notifications in the background, rather than blocking until the
notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and
Pusher.on_new_receipts are not expected to return deferreds. Pusher.on_new_receipts are not expected to return awaitables.
""" """
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
@ -77,8 +75,7 @@ class PusherPool:
return return
run_as_background_process("start_pushers", self._start_pushers) run_as_background_process("start_pushers", self._start_pushers)
@defer.inlineCallbacks async def add_pusher(
def add_pusher(
self, self,
user_id, user_id,
access_token, access_token,
@ -94,7 +91,7 @@ class PusherPool:
"""Creates a new pusher and adds it to the pool """Creates a new pusher and adds it to the pool
Returns: Returns:
Deferred[EmailPusher|HttpPusher] EmailPusher|HttpPusher
""" """
time_now_msec = self.clock.time_msec() time_now_msec = self.clock.time_msec()
@ -124,9 +121,9 @@ class PusherPool:
# create the pusher setting last_stream_ordering to the current maximum # create the pusher setting last_stream_ordering to the current maximum
# stream ordering in event_push_actions, so it will process # stream ordering in event_push_actions, so it will process
# pushes from this point onwards. # pushes from this point onwards.
last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering() last_stream_ordering = await self.store.get_latest_push_action_stream_ordering()
yield self.store.add_pusher( await self.store.add_pusher(
user_id=user_id, user_id=user_id,
access_token=access_token, access_token=access_token,
kind=kind, kind=kind,
@ -140,15 +137,14 @@ class PusherPool:
last_stream_ordering=last_stream_ordering, last_stream_ordering=last_stream_ordering,
profile_tag=profile_tag, profile_tag=profile_tag,
) )
pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id) pusher = await self.start_pusher_by_id(app_id, pushkey, user_id)
return pusher return pusher
@defer.inlineCallbacks async def remove_pushers_by_app_id_and_pushkey_not_user(
def remove_pushers_by_app_id_and_pushkey_not_user(
self, app_id, pushkey, not_user_id self, app_id, pushkey, not_user_id
): ):
to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
for p in to_remove: for p in to_remove:
if p["user_name"] != not_user_id: if p["user_name"] != not_user_id:
logger.info( logger.info(
@ -157,10 +153,9 @@ class PusherPool:
pushkey, pushkey,
p["user_name"], p["user_name"],
) )
yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
@defer.inlineCallbacks async def remove_pushers_by_access_token(self, user_id, access_tokens):
def remove_pushers_by_access_token(self, user_id, access_tokens):
"""Remove the pushers for a given user corresponding to a set of """Remove the pushers for a given user corresponding to a set of
access_tokens. access_tokens.
@ -173,7 +168,7 @@ class PusherPool:
return return
tokens = set(access_tokens) tokens = set(access_tokens)
for p in (yield self.store.get_pushers_by_user_id(user_id)): for p in await self.store.get_pushers_by_user_id(user_id):
if p["access_token"] in tokens: if p["access_token"] in tokens:
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
@ -181,16 +176,15 @@ class PusherPool:
p["pushkey"], p["pushkey"],
p["user_name"], p["user_name"],
) )
yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
@defer.inlineCallbacks async def on_new_notifications(self, min_stream_id, max_stream_id):
def on_new_notifications(self, min_stream_id, max_stream_id):
if not self.pushers: if not self.pushers:
# nothing to do here. # nothing to do here.
return return
try: try:
users_affected = yield self.store.get_push_action_users_in_range( users_affected = await self.store.get_push_action_users_in_range(
min_stream_id, max_stream_id min_stream_id, max_stream_id
) )
@ -202,8 +196,7 @@ class PusherPool:
except Exception: except Exception:
logger.exception("Exception in pusher on_new_notifications") logger.exception("Exception in pusher on_new_notifications")
@defer.inlineCallbacks async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
if not self.pushers: if not self.pushers:
# nothing to do here. # nothing to do here.
return return
@ -211,7 +204,7 @@ class PusherPool:
try: try:
# Need to subtract 1 from the minimum because the lower bound here # Need to subtract 1 from the minimum because the lower bound here
# is not inclusive # is not inclusive
users_affected = yield self.store.get_users_sent_receipts_between( users_affected = await self.store.get_users_sent_receipts_between(
min_stream_id - 1, max_stream_id min_stream_id - 1, max_stream_id
) )
@ -223,12 +216,11 @@ class PusherPool:
except Exception: except Exception:
logger.exception("Exception in pusher on_new_receipts") logger.exception("Exception in pusher on_new_receipts")
@defer.inlineCallbacks async def start_pusher_by_id(self, app_id, pushkey, user_id):
def start_pusher_by_id(self, app_id, pushkey, user_id):
"""Look up the details for the given pusher, and start it """Look up the details for the given pusher, and start it
Returns: Returns:
Deferred[EmailPusher|HttpPusher|None]: The pusher started, if any EmailPusher|HttpPusher|None: The pusher started, if any
""" """
if not self._should_start_pushers: if not self._should_start_pushers:
return return
@ -236,7 +228,7 @@ class PusherPool:
if not self._pusher_shard_config.should_handle(self._instance_name, user_id): if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return return
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
pusher_dict = None pusher_dict = None
for r in resultlist: for r in resultlist:
@ -245,34 +237,29 @@ class PusherPool:
pusher = None pusher = None
if pusher_dict: if pusher_dict:
pusher = yield self._start_pusher(pusher_dict) pusher = await self._start_pusher(pusher_dict)
return pusher return pusher
@defer.inlineCallbacks async def _start_pushers(self) -> None:
def _start_pushers(self):
"""Start all the pushers """Start all the pushers
Returns:
Deferred
""" """
pushers = yield self.store.get_all_pushers() pushers = await self.store.get_all_pushers()
# Stagger starting up the pushers so we don't completely drown the # Stagger starting up the pushers so we don't completely drown the
# process on start up. # process on start up.
yield concurrently_execute(self._start_pusher, pushers, 10) await concurrently_execute(self._start_pusher, pushers, 10)
logger.info("Started pushers") logger.info("Started pushers")
@defer.inlineCallbacks async def _start_pusher(self, pusherdict):
def _start_pusher(self, pusherdict):
"""Start the given pusher """Start the given pusher
Args: Args:
pusherdict (dict): dict with the values pulled from the db table pusherdict (dict): dict with the values pulled from the db table
Returns: Returns:
Deferred[EmailPusher|HttpPusher] EmailPusher|HttpPusher
""" """
if not self._pusher_shard_config.should_handle( if not self._pusher_shard_config.should_handle(
self._instance_name, pusherdict["user_name"] self._instance_name, pusherdict["user_name"]
@ -315,7 +302,7 @@ class PusherPool:
user_id = pusherdict["user_name"] user_id = pusherdict["user_name"]
last_stream_ordering = pusherdict["last_stream_ordering"] last_stream_ordering = pusherdict["last_stream_ordering"]
if last_stream_ordering: if last_stream_ordering:
have_notifs = yield self.store.get_if_maybe_push_in_range_for_user( have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
user_id, last_stream_ordering user_id, last_stream_ordering
) )
else: else:
@ -327,8 +314,7 @@ class PusherPool:
return p return p
@defer.inlineCallbacks async def remove_pusher(self, app_id, pushkey, user_id):
def remove_pusher(self, app_id, pushkey, user_id):
appid_pushkey = "%s:%s" % (app_id, pushkey) appid_pushkey = "%s:%s" % (app_id, pushkey)
byuser = self.pushers.get(user_id, {}) byuser = self.pushers.get(user_id, {})
@ -340,6 +326,6 @@ class PusherPool:
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
yield self.store.delete_pusher_by_app_id_pushkey_user_id( await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id app_id, pushkey, user_id
) )

View file

@ -43,7 +43,7 @@ REQUIREMENTS = [
"jsonschema>=2.5.1", "jsonschema>=2.5.1",
"frozendict>=1", "frozendict>=1",
"unpaddedbase64>=1.1.0", "unpaddedbase64>=1.1.0",
"canonicaljson>=1.1.3", "canonicaljson>=1.2.0",
# we use the type definitions added in signedjson 1.1. # we use the type definitions added in signedjson 1.1.
"signedjson>=1.1.0", "signedjson>=1.1.0",
"pynacl>=1.2.1", "pynacl>=1.2.1",

View file

@ -78,7 +78,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
""" """
event_payloads = [] event_payloads = []
for event, context in event_and_contexts: for event, context in event_and_contexts:
serialized_context = yield context.serialize(event, store) serialized_context = yield defer.ensureDeferred(
context.serialize(event, store)
)
event_payloads.append( event_payloads.append(
{ {

View file

@ -77,7 +77,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
extra_users (list(UserID)): Any extra users to notify about event extra_users (list(UserID)): Any extra users to notify about event
""" """
serialized_context = yield context.serialize(event, store) serialized_context = yield defer.ensureDeferred(context.serialize(event, store))
payload = { payload = {
"event": event.get_pdu_json(), "event": event.get_pdu_json(),

View file

@ -103,6 +103,14 @@ class DeleteRoomRestServlet(RestServlet):
Codes.BAD_JSON, Codes.BAD_JSON,
) )
purge = content.get("purge", True)
if not isinstance(purge, bool):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Param 'purge' must be a boolean, if given",
Codes.BAD_JSON,
)
ret = await self.room_shutdown_handler.shutdown_room( ret = await self.room_shutdown_handler.shutdown_room(
room_id=room_id, room_id=room_id,
new_room_user_id=content.get("new_room_user_id"), new_room_user_id=content.get("new_room_user_id"),
@ -113,7 +121,8 @@ class DeleteRoomRestServlet(RestServlet):
) )
# Purge room # Purge room
await self.pagination_handler.purge_room(room_id) if purge:
await self.pagination_handler.purge_room(room_id)
return (200, ret) return (200, ret)

View file

@ -426,6 +426,7 @@ class SyncRestServlet(RestServlet):
result["ephemeral"] = {"events": ephemeral_events} result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary result["summary"] = room.summary
result["org.matrix.msc2654.unread_count"] = room.unread_count
return result return result

View file

@ -17,7 +17,9 @@
import logging import logging
import os import os
import urllib import urllib
from typing import Awaitable
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from synapse.api.errors import Codes, SynapseError, cs_error from synapse.api.errors import Codes, SynapseError, cs_error
@ -240,14 +242,14 @@ class Responder(object):
held can be cleaned up. held can be cleaned up.
""" """
def write_to_consumer(self, consumer): def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
"""Stream response into consumer """Stream response into consumer
Args: Args:
consumer (IConsumer) consumer: The consumer to stream into.
Returns: Returns:
Deferred: Resolves once the response has finished being written Resolves once the response has finished being written
""" """
pass pass

View file

@ -18,10 +18,11 @@ import errno
import logging import logging
import os import os
import shutil import shutil
from typing import Dict, Tuple from typing import IO, Dict, Optional, Tuple
import twisted.internet.error import twisted.internet.error
import twisted.web.http import twisted.web.http
from twisted.web.http import Request
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.api.errors import ( from synapse.api.errors import (
@ -40,6 +41,7 @@ from synapse.util.stringutils import random_string
from ._base import ( from ._base import (
FileInfo, FileInfo,
Responder,
get_filename_from_headers, get_filename_from_headers,
respond_404, respond_404,
respond_with_responder, respond_with_responder,
@ -135,19 +137,24 @@ class MediaRepository(object):
self.recently_accessed_locals.add(media_id) self.recently_accessed_locals.add(media_id)
async def create_content( async def create_content(
self, media_type, upload_name, content, content_length, auth_user self,
): media_type: str,
upload_name: str,
content: IO,
content_length: int,
auth_user: str,
) -> str:
"""Store uploaded content for a local user and return the mxc URL """Store uploaded content for a local user and return the mxc URL
Args: Args:
media_type(str): The content type of the file media_type: The content type of the file
upload_name(str): The name of the file upload_name: The name of the file
content: A file like object that is the content to store content: A file like object that is the content to store
content_length(int): The length of the content content_length: The length of the content
auth_user(str): The user_id of the uploader auth_user: The user_id of the uploader
Returns: Returns:
Deferred[str]: The mxc url of the stored content The mxc url of the stored content
""" """
media_id = random_string(24) media_id = random_string(24)
@ -170,19 +177,20 @@ class MediaRepository(object):
return "mxc://%s/%s" % (self.server_name, media_id) return "mxc://%s/%s" % (self.server_name, media_id)
async def get_local_media(self, request, media_id, name): async def get_local_media(
self, request: Request, media_id: str, name: Optional[str]
) -> None:
"""Responds to reqests for local media, if exists, or returns 404. """Responds to reqests for local media, if exists, or returns 404.
Args: Args:
request(twisted.web.http.Request) request: The incoming request.
media_id (str): The media ID of the content. (This is the same as media_id: The media ID of the content. (This is the same as
the file_id for local content.) the file_id for local content.)
name (str|None): Optional name that, if specified, will be used as name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response. the filename in the Content-Disposition header of the response.
Returns: Returns:
Deferred: Resolves once a response has successfully been written Resolves once a response has successfully been written to request
to request
""" """
media_info = await self.store.get_local_media(media_id) media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]: if not media_info or media_info["quarantined_by"]:
@ -203,20 +211,20 @@ class MediaRepository(object):
request, responder, media_type, media_length, upload_name request, responder, media_type, media_length, upload_name
) )
async def get_remote_media(self, request, server_name, media_id, name): async def get_remote_media(
self, request: Request, server_name: str, media_id: str, name: Optional[str]
) -> None:
"""Respond to requests for remote media. """Respond to requests for remote media.
Args: Args:
request(twisted.web.http.Request) request: The incoming request.
server_name (str): Remote server_name where the media originated. server_name: Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the media_id: The media ID of the content (as defined by the remote server).
remote server). name: Optional name that, if specified, will be used as
name (str|None): Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response. the filename in the Content-Disposition header of the response.
Returns: Returns:
Deferred: Resolves once a response has successfully been written Resolves once a response has successfully been written to request
to request
""" """
if ( if (
self.federation_domain_whitelist is not None self.federation_domain_whitelist is not None
@ -245,17 +253,16 @@ class MediaRepository(object):
else: else:
respond_404(request) respond_404(request)
async def get_remote_media_info(self, server_name, media_id): async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
"""Gets the media info associated with the remote file, downloading """Gets the media info associated with the remote file, downloading
if necessary. if necessary.
Args: Args:
server_name (str): Remote server_name where the media originated. server_name: Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the media_id: The media ID of the content (as defined by the remote server).
remote server).
Returns: Returns:
Deferred[dict]: The media_info of the file The media info of the file
""" """
if ( if (
self.federation_domain_whitelist is not None self.federation_domain_whitelist is not None
@ -278,7 +285,9 @@ class MediaRepository(object):
return media_info return media_info
async def _get_remote_media_impl(self, server_name, media_id): async def _get_remote_media_impl(
self, server_name: str, media_id: str
) -> Tuple[Optional[Responder], dict]:
"""Looks for media in local cache, if not there then attempt to """Looks for media in local cache, if not there then attempt to
download from remote server. download from remote server.
@ -288,7 +297,7 @@ class MediaRepository(object):
remote server). remote server).
Returns: Returns:
Deferred[(Responder, media_info)] A tuple of responder and the media info of the file.
""" """
media_info = await self.store.get_cached_remote_media(server_name, media_id) media_info = await self.store.get_cached_remote_media(server_name, media_id)
@ -319,19 +328,21 @@ class MediaRepository(object):
responder = await self.media_storage.fetch_media(file_info) responder = await self.media_storage.fetch_media(file_info)
return responder, media_info return responder, media_info
async def _download_remote_file(self, server_name, media_id, file_id): async def _download_remote_file(
self, server_name: str, media_id: str, file_id: str
) -> dict:
"""Attempt to download the remote file from the given server name, """Attempt to download the remote file from the given server name,
using the given file_id as the local id. using the given file_id as the local id.
Args: Args:
server_name (str): Originating server server_name: Originating server
media_id (str): The media ID of the content (as defined by the media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is remote server). This is different than the file_id, which is
locally generated. locally generated.
file_id (str): Local file ID file_id: Local file ID
Returns: Returns:
Deferred[MediaInfo] The media info of the file.
""" """
file_info = FileInfo(server_name=server_name, file_id=file_id) file_info = FileInfo(server_name=server_name, file_id=file_id)
@ -549,25 +560,31 @@ class MediaRepository(object):
return output_path return output_path
async def _generate_thumbnails( async def _generate_thumbnails(
self, server_name, media_id, file_id, media_type, url_cache=False self,
): server_name: Optional[str],
media_id: str,
file_id: str,
media_type: str,
url_cache: bool = False,
) -> Optional[dict]:
"""Generate and store thumbnails for an image. """Generate and store thumbnails for an image.
Args: Args:
server_name (str|None): The server name if remote media, else None if local server_name: The server name if remote media, else None if local
media_id (str): The media ID of the content. (This is the same as media_id: The media ID of the content. (This is the same as
the file_id for local content) the file_id for local content)
file_id (str): Local file ID file_id: Local file ID
media_type (str): The content type of the file media_type: The content type of the file
url_cache (bool): If we are thumbnailing images downloaded for the URL cache, url_cache: If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer used exclusively by the url previewer
Returns: Returns:
Deferred[dict]: Dict with "width" and "height" keys of original image Dict with "width" and "height" keys of original image or None if the
media cannot be thumbnailed.
""" """
requirements = self._get_thumbnail_requirements(media_type) requirements = self._get_thumbnail_requirements(media_type)
if not requirements: if not requirements:
return return None
input_path = await self.media_storage.ensure_media_is_in_local_cache( input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache) FileInfo(server_name, file_id, url_cache=url_cache)
@ -584,7 +601,7 @@ class MediaRepository(object):
m_height, m_height,
self.max_image_pixels, self.max_image_pixels,
) )
return return None
if thumbnailer.transpose_method is not None: if thumbnailer.transpose_method is not None:
m_width, m_height = await defer_to_thread( m_width, m_height = await defer_to_thread(

View file

@ -12,13 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib import contextlib
import inspect import inspect
import logging import logging
import os import os
import shutil import shutil
from typing import Optional from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
@ -26,6 +25,12 @@ from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util.file_consumer import BackgroundFileConsumer from synapse.util.file_consumer import BackgroundFileConsumer
from ._base import FileInfo, Responder from ._base import FileInfo, Responder
from .filepath import MediaFilePaths
if TYPE_CHECKING:
from synapse.server import HomeServer
from .storage_provider import StorageProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,20 +39,25 @@ class MediaStorage(object):
"""Responsible for storing/fetching files from local sources. """Responsible for storing/fetching files from local sources.
Args: Args:
hs (synapse.server.Homeserver) hs
local_media_directory (str): Base path where we store media on disk local_media_directory: Base path where we store media on disk
filepaths (MediaFilePaths) filepaths
storage_providers ([StorageProvider]): List of StorageProvider that are storage_providers: List of StorageProvider that are used to fetch and store files.
used to fetch and store files.
""" """
def __init__(self, hs, local_media_directory, filepaths, storage_providers): def __init__(
self,
hs: "HomeServer",
local_media_directory: str,
filepaths: MediaFilePaths,
storage_providers: Sequence["StorageProvider"],
):
self.hs = hs self.hs = hs
self.local_media_directory = local_media_directory self.local_media_directory = local_media_directory
self.filepaths = filepaths self.filepaths = filepaths
self.storage_providers = storage_providers self.storage_providers = storage_providers
async def store_file(self, source, file_info: FileInfo) -> str: async def store_file(self, source: IO, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other """Write `source` to the on disk media store, and also any other
configured storage providers configured storage providers
@ -69,7 +79,7 @@ class MediaStorage(object):
return fname return fname
@contextlib.contextmanager @contextlib.contextmanager
def store_into_file(self, file_info): def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as """Context manager used to get a file like object to write into, as
described by file_info. described by file_info.
@ -85,7 +95,7 @@ class MediaStorage(object):
error. error.
Args: Args:
file_info (FileInfo): Info about the file to store file_info: Info about the file to store
Example: Example:
@ -143,9 +153,9 @@ class MediaStorage(object):
return FileResponder(open(local_path, "rb")) return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers: for provider in self.storage_providers:
res = provider.fetch(path, file_info) res = provider.fetch(path, file_info) # type: Any
# Fetch is supposed to return an Awaitable, but guard against # Fetch is supposed to return an Awaitable[Responder], but guard
# improper implementations. # against improper implementations.
if inspect.isawaitable(res): if inspect.isawaitable(res):
res = await res res = await res
if res: if res:
@ -174,9 +184,9 @@ class MediaStorage(object):
os.makedirs(dirname) os.makedirs(dirname)
for provider in self.storage_providers: for provider in self.storage_providers:
res = provider.fetch(path, file_info) res = provider.fetch(path, file_info) # type: Any
# Fetch is supposed to return an Awaitable, but guard against # Fetch is supposed to return an Awaitable[Responder], but guard
# improper implementations. # against improper implementations.
if inspect.isawaitable(res): if inspect.isawaitable(res):
res = await res res = await res
if res: if res:
@ -190,17 +200,11 @@ class MediaStorage(object):
raise Exception("file could not be found") raise Exception("file could not be found")
def _file_info_to_path(self, file_info): def _file_info_to_path(self, file_info: FileInfo) -> str:
"""Converts file_info into a relative path. """Converts file_info into a relative path.
The path is suitable for storing files under a directory, e.g. used to The path is suitable for storing files under a directory, e.g. used to
store files on local FS under the base media repository directory. store files on local FS under the base media repository directory.
Args:
file_info (FileInfo)
Returns:
str
""" """
if file_info.url_cache: if file_info.url_cache:
if file_info.thumbnail: if file_info.thumbnail:

View file

@ -231,16 +231,16 @@ class PreviewUrlResource(DirectServeJsonResource):
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe)) og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
respond_with_json_bytes(request, 200, og, send_cors=True) respond_with_json_bytes(request, 200, og, send_cors=True)
async def _do_preview(self, url, user, ts): async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
"""Check the db, and download the URL and build a preview """Check the db, and download the URL and build a preview
Args: Args:
url (str): url: The URL to preview.
user (str): user: The user requesting the preview.
ts (int): ts: The timestamp requested for the preview.
Returns: Returns:
Deferred[bytes]: json-encoded og data json-encoded og data
""" """
# check the URL cache in the DB (which will also provide us with # check the URL cache in the DB (which will also provide us with
# historical previews, if we have any) # historical previews, if we have any)

View file

@ -16,62 +16,62 @@
import logging import logging
import os import os
import shutil import shutil
from typing import Optional
from twisted.internet import defer
from synapse.config._base import Config from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background from synapse.logging.context import defer_to_thread, run_in_background
from ._base import FileInfo, Responder
from .media_storage import FileResponder from .media_storage import FileResponder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class StorageProvider(object): class StorageProvider:
"""A storage provider is a service that can store uploaded media and """A storage provider is a service that can store uploaded media and
retrieve them. retrieve them.
""" """
def store_file(self, path, file_info): async def store_file(self, path: str, file_info: FileInfo):
"""Store the file described by file_info. The actual contents can be """Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path. retrieved by reading the file in file_info.upload_path.
Args: Args:
path (str): Relative path of file in local cache path: Relative path of file in local cache
file_info (FileInfo) file_info: The metadata of the file.
Returns:
Deferred
""" """
pass
def fetch(self, path, file_info): async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it """Attempt to fetch the file described by file_info and stream it
into writer. into writer.
Args: Args:
path (str): Relative path of file in local cache path: Relative path of file in local cache
file_info (FileInfo) file_info: The metadata of the file.
Returns: Returns:
Deferred(Responder): Returns a Responder if the provider has the file, Returns a Responder if the provider has the file, otherwise returns None.
otherwise returns None.
""" """
pass
class StorageProviderWrapper(StorageProvider): class StorageProviderWrapper(StorageProvider):
"""Wraps a storage provider and provides various config options """Wraps a storage provider and provides various config options
Args: Args:
backend (StorageProvider) backend: The storage provider to wrap.
store_local (bool): Whether to store new local files or not. store_local: Whether to store new local files or not.
store_synchronous (bool): Whether to wait for file to be successfully store_synchronous: Whether to wait for file to be successfully
uploaded, or todo the upload in the background. uploaded, or todo the upload in the background.
store_remote (bool): Whether remote media should be uploaded store_remote: Whether remote media should be uploaded
""" """
def __init__(self, backend, store_local, store_synchronous, store_remote): def __init__(
self,
backend: StorageProvider,
store_local: bool,
store_synchronous: bool,
store_remote: bool,
):
self.backend = backend self.backend = backend
self.store_local = store_local self.store_local = store_local
self.store_synchronous = store_synchronous self.store_synchronous = store_synchronous
@ -80,15 +80,15 @@ class StorageProviderWrapper(StorageProvider):
def __str__(self): def __str__(self):
return "StorageProviderWrapper[%s]" % (self.backend,) return "StorageProviderWrapper[%s]" % (self.backend,)
def store_file(self, path, file_info): async def store_file(self, path, file_info):
if not file_info.server_name and not self.store_local: if not file_info.server_name and not self.store_local:
return defer.succeed(None) return None
if file_info.server_name and not self.store_remote: if file_info.server_name and not self.store_remote:
return defer.succeed(None) return None
if self.store_synchronous: if self.store_synchronous:
return self.backend.store_file(path, file_info) return await self.backend.store_file(path, file_info)
else: else:
# TODO: Handle errors. # TODO: Handle errors.
def store(): def store():
@ -98,10 +98,10 @@ class StorageProviderWrapper(StorageProvider):
logger.exception("Error storing file") logger.exception("Error storing file")
run_in_background(store) run_in_background(store)
return defer.succeed(None) return None
def fetch(self, path, file_info): async def fetch(self, path, file_info):
return self.backend.fetch(path, file_info) return await self.backend.fetch(path, file_info)
class FileStorageProviderBackend(StorageProvider): class FileStorageProviderBackend(StorageProvider):
@ -120,7 +120,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self): def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,) return "FileStorageProviderBackend[%s]" % (self.base_directory,)
def store_file(self, path, file_info): async def store_file(self, path, file_info):
"""See StorageProvider.store_file""" """See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path) primary_fname = os.path.join(self.cache_directory, path)
@ -130,11 +130,11 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
return defer_to_thread( return await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
) )
def fetch(self, path, file_info): async def fetch(self, path, file_info):
"""See StorageProvider.fetch""" """See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path) backup_fname = os.path.join(self.base_directory, path)

View file

@ -172,6 +172,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_latest_event_ids_in_room.invalidate((room_id,)) self.get_latest_event_ids_in_room.invalidate((room_id,))
self.get_unread_message_count_for_user.invalidate_many((room_id,))
self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
if not backfilled: if not backfilled:

View file

@ -15,11 +15,10 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import Database from synapse.storage.database import Database
@ -166,8 +165,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return {"notify_count": notify_count, "highlight_count": highlight_count} return {"notify_count": notify_count, "highlight_count": highlight_count}
@defer.inlineCallbacks async def get_push_action_users_in_range(
def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): self, min_stream_ordering, max_stream_ordering
):
def f(txn): def f(txn):
sql = ( sql = (
"SELECT DISTINCT(user_id) FROM event_push_actions WHERE" "SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
@ -176,26 +176,28 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (min_stream_ordering, max_stream_ordering)) txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn] return [r[0] for r in txn]
ret = yield self.db.runInteraction("get_push_action_users_in_range", f) ret = await self.db.runInteraction("get_push_action_users_in_range", f)
return ret return ret
@defer.inlineCallbacks async def get_unread_push_actions_for_user_in_range_for_http(
def get_unread_push_actions_for_user_in_range_for_http( self,
self, user_id, min_stream_ordering, max_stream_ordering, limit=20 user_id: str,
): min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
) -> List[dict]:
"""Get a list of the most recent unread push actions for a given user, """Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the httppusher. within the given stream ordering range. Called by the httppusher.
Args: Args:
user_id (str): The user to fetch push actions for. user_id: The user to fetch push actions for.
min_stream_ordering(int): The exclusive lower bound on the min_stream_ordering: The exclusive lower bound on the
stream ordering of event push actions to fetch. stream ordering of event push actions to fetch.
max_stream_ordering(int): The inclusive upper bound on the max_stream_ordering: The inclusive upper bound on the
stream ordering of event push actions to fetch. stream ordering of event push actions to fetch.
limit (int): The maximum number of rows to return. limit: The maximum number of rows to return.
Returns: Returns:
A promise which resolves to a list of dicts with the keys "event_id", A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions".
"room_id", "stream_ordering", "actions".
The list will be ordered by ascending stream_ordering. The list will be ordered by ascending stream_ordering.
The list will have between 0~limit entries. The list will have between 0~limit entries.
""" """
@ -228,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
after_read_receipt = yield self.db.runInteraction( after_read_receipt = await self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
) )
@ -256,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
no_read_receipt = yield self.db.runInteraction( no_read_receipt = await self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
) )
@ -280,23 +282,25 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# one of the subqueries may have hit the limit. # one of the subqueries may have hit the limit.
return notifs[:limit] return notifs[:limit]
@defer.inlineCallbacks async def get_unread_push_actions_for_user_in_range_for_email(
def get_unread_push_actions_for_user_in_range_for_email( self,
self, user_id, min_stream_ordering, max_stream_ordering, limit=20 user_id: str,
): min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
) -> List[dict]:
"""Get a list of the most recent unread push actions for a given user, """Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the emailpusher within the given stream ordering range. Called by the emailpusher
Args: Args:
user_id (str): The user to fetch push actions for. user_id: The user to fetch push actions for.
min_stream_ordering(int): The exclusive lower bound on the min_stream_ordering: The exclusive lower bound on the
stream ordering of event push actions to fetch. stream ordering of event push actions to fetch.
max_stream_ordering(int): The inclusive upper bound on the max_stream_ordering: The inclusive upper bound on the
stream ordering of event push actions to fetch. stream ordering of event push actions to fetch.
limit (int): The maximum number of rows to return. limit: The maximum number of rows to return.
Returns: Returns:
A promise which resolves to a list of dicts with the keys "event_id", A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts".
"room_id", "stream_ordering", "actions", "received_ts".
The list will be ordered by descending received_ts. The list will be ordered by descending received_ts.
The list will have between 0~limit entries. The list will have between 0~limit entries.
""" """
@ -328,7 +332,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
after_read_receipt = yield self.db.runInteraction( after_read_receipt = await self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
) )
@ -356,7 +360,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
no_read_receipt = yield self.db.runInteraction( no_read_receipt = await self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
) )
@ -411,7 +415,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
_get_if_maybe_push_in_range_for_user_txn, _get_if_maybe_push_in_range_for_user_txn,
) )
def add_push_actions_to_staging(self, event_id, user_id_actions): async def add_push_actions_to_staging(self, event_id, user_id_actions):
"""Add the push actions for the event to the push action staging area. """Add the push actions for the event to the push action staging area.
Args: Args:
@ -457,21 +461,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
), ),
) )
return self.db.runInteraction( return await self.db.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn "add_push_actions_to_staging", _add_push_actions_to_staging_txn
) )
@defer.inlineCallbacks async def remove_push_actions_from_staging(self, event_id: str) -> None:
def remove_push_actions_from_staging(self, event_id):
"""Called if we failed to persist the event to ensure that stale push """Called if we failed to persist the event to ensure that stale push
actions don't build up in the DB actions don't build up in the DB
Args:
event_id (str)
""" """
try: try:
res = yield self.db.simple_delete( res = await self.db.simple_delete(
table="event_push_actions_staging", table="event_push_actions_staging",
keyvalues={"event_id": event_id}, keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging", desc="remove_push_actions_from_staging",
@ -606,8 +606,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return range_end return range_end
@defer.inlineCallbacks async def get_time_of_last_push_action_before(self, stream_ordering):
def get_time_of_last_push_action_before(self, stream_ordering):
def f(txn): def f(txn):
sql = ( sql = (
"SELECT e.received_ts" "SELECT e.received_ts"
@ -620,7 +619,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (stream_ordering,)) txn.execute(sql, (stream_ordering,))
return txn.fetchone() return txn.fetchone()
result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) result = await self.db.runInteraction("get_time_of_last_push_action_before", f)
return result[0] if result else None return result[0] if result else None
@ -650,8 +649,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
self._start_rotate_notifs, 30 * 60 * 1000 self._start_rotate_notifs, 30 * 60 * 1000
) )
@defer.inlineCallbacks async def get_push_actions_for_user(
def get_push_actions_for_user(
self, user_id, before=None, limit=50, only_highlight=False self, user_id, before=None, limit=50, only_highlight=False
): ):
def f(txn): def f(txn):
@ -682,18 +680,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn.execute(sql, args) txn.execute(sql, args)
return self.db.cursor_to_dict(txn) return self.db.cursor_to_dict(txn)
push_actions = yield self.db.runInteraction("get_push_actions_for_user", f) push_actions = await self.db.runInteraction("get_push_actions_for_user", f)
for pa in push_actions: for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions return push_actions
@defer.inlineCallbacks async def get_latest_push_action_stream_ordering(self):
def get_latest_push_action_stream_ordering(self):
def f(txn): def f(txn):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone() return txn.fetchone()
result = yield self.db.runInteraction( result = await self.db.runInteraction(
"get_latest_push_action_stream_ordering", f "get_latest_push_action_stream_ordering", f
) )
return result[0] or 0 return result[0] or 0
@ -747,8 +744,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
def _start_rotate_notifs(self): def _start_rotate_notifs(self):
return run_as_background_process("rotate_notifs", self._rotate_notifs) return run_as_background_process("rotate_notifs", self._rotate_notifs)
@defer.inlineCallbacks async def _rotate_notifs(self):
def _rotate_notifs(self):
if self._doing_notif_rotation or self.stream_ordering_day_ago is None: if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
return return
self._doing_notif_rotation = True self._doing_notif_rotation = True
@ -757,12 +753,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
while True: while True:
logger.info("Rotating notifications") logger.info("Rotating notifications")
caught_up = yield self.db.runInteraction( caught_up = await self.db.runInteraction(
"_rotate_notifs", self._rotate_notifs_txn "_rotate_notifs", self._rotate_notifs_txn
) )
if caught_up: if caught_up:
break break
yield self.hs.get_clock().sleep(self._rotate_delay) await self.hs.get_clock().sleep(self._rotate_delay)
finally: finally:
self._doing_notif_rotation = False self._doing_notif_rotation = False

View file

@ -53,6 +53,47 @@ event_counter = Counter(
["type", "origin_type", "origin_entity"], ["type", "origin_type", "origin_entity"],
) )
STATE_EVENT_TYPES_TO_MARK_UNREAD = {
EventTypes.Topic,
EventTypes.Name,
EventTypes.RoomAvatar,
EventTypes.Tombstone,
}
def should_count_as_unread(event: EventBase, context: EventContext) -> bool:
# Exclude rejected and soft-failed events.
if context.rejected or event.internal_metadata.is_soft_failed():
return False
# Exclude notices.
if (
not event.is_state()
and event.type == EventTypes.Message
and event.content.get("msgtype") == "m.notice"
):
return False
# Exclude edits.
relates_to = event.content.get("m.relates_to", {})
if relates_to.get("rel_type") == RelationTypes.REPLACE:
return False
# Mark events that have a non-empty string body as unread.
body = event.content.get("body")
if isinstance(body, str) and body:
return True
# Mark some state events as unread.
if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
return True
# Mark encrypted events as unread.
if not event.is_state() and event.type == EventTypes.Encrypted:
return True
return False
def encode_json(json_object): def encode_json(json_object):
""" """
@ -196,6 +237,10 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc() event_counter.labels(event.type, origin_type, origin_entity).inc()
self.store.get_unread_message_count_for_user.invalidate_many(
(event.room_id,),
)
for room_id, new_state in current_state_for_room.items(): for room_id, new_state in current_state_for_room.items():
self.store.get_current_state_ids.prefill((room_id,), new_state) self.store.get_current_state_ids.prefill((room_id,), new_state)
@ -817,8 +862,9 @@ class PersistEventsStore:
"contains_url": ( "contains_url": (
"url" in event.content and isinstance(event.content["url"], str) "url" in event.content and isinstance(event.content["url"], str)
), ),
"count_as_unread": should_count_as_unread(event, context),
} }
for event, _ in events_and_contexts for event, context in events_and_contexts
], ],
) )

View file

@ -41,9 +41,15 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database from synapse.storage.database import Database
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import StreamIdGenerator from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks from synapse.util.caches.descriptors import (
Cache,
_CacheContext,
cached,
cachedInlineCallbacks,
)
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -1358,6 +1364,84 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
) )
@cached(tree=True, cache_context=True)
async def get_unread_message_count_for_user(
self, room_id: str, user_id: str, cache_context: _CacheContext,
) -> int:
"""Retrieve the count of unread messages for the given room and user.
Args:
room_id: The ID of the room to count unread messages in.
user_id: The ID of the user to count unread messages for.
Returns:
The number of unread messages for the given user in the given room.
"""
with Measure(self._clock, "get_unread_message_count_for_user"):
last_read_event_id = await self.get_last_receipt_event_id_for_user(
user_id=user_id,
room_id=room_id,
receipt_type="m.read",
on_invalidate=cache_context.invalidate,
)
return await self.db.runInteraction(
"get_unread_message_count_for_user",
self._get_unread_message_count_for_user_txn,
user_id,
room_id,
last_read_event_id,
)
def _get_unread_message_count_for_user_txn(
self,
txn: Cursor,
user_id: str,
room_id: str,
last_read_event_id: Optional[str],
) -> int:
if last_read_event_id:
# Get the stream ordering for the last read event.
stream_ordering = self.db.simple_select_one_onecol_txn(
txn=txn,
table="events",
keyvalues={"room_id": room_id, "event_id": last_read_event_id},
retcol="stream_ordering",
)
else:
# If there's no read receipt for that room, it probably means the user hasn't
# opened it yet, in which case use the stream ID of their join event.
# We can't just set it to 0 otherwise messages from other local users from
# before this user joined will be counted as well.
txn.execute(
"""
SELECT stream_ordering FROM local_current_membership
LEFT JOIN events USING (event_id, room_id)
WHERE membership = 'join'
AND user_id = ?
AND room_id = ?
""",
(user_id, room_id),
)
row = txn.fetchone()
if row is None:
return 0
stream_ordering = row[0]
# Count the messages that qualify as unread after the stream ordering we've just
# retrieved.
sql = """
SELECT COUNT(*) FROM events
WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread
"""
txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
return row[0] if row else 0
AllNewEventsResult = namedtuple( AllNewEventsResult = namedtuple(
"AllNewEventsResult", "AllNewEventsResult",

View file

@ -62,6 +62,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
# event_json # event_json
# event_push_actions # event_push_actions
# event_reference_hashes # event_reference_hashes
# event_relations
# event_search # event_search
# event_to_state_groups # event_to_state_groups
# events # events
@ -209,6 +210,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"event_edges", "event_edges",
"event_forward_extremities", "event_forward_extremities",
"event_reference_hashes", "event_reference_hashes",
"event_relations",
"event_search", "event_search",
"rejections", "rejections",
): ):

View file

@ -23,8 +23,6 @@ from typing import Any, Dict, List, Optional, Tuple
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.api.room_versions import RoomVersion, RoomVersions
@ -32,7 +30,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.search import SearchStore from synapse.storage.data_stores.main.search import SearchStore
from synapse.storage.database import Database, LoggingTransaction from synapse.storage.database import Database, LoggingTransaction
from synapse.types import ThirdPartyInstanceID from synapse.types import ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -192,8 +190,7 @@ class RoomWorkerStore(SQLBaseStore):
return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn) return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
@defer.inlineCallbacks async def get_largest_public_rooms(
def get_largest_public_rooms(
self, self,
network_tuple: Optional[ThirdPartyInstanceID], network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict], search_filter: Optional[dict],
@ -330,10 +327,10 @@ class RoomWorkerStore(SQLBaseStore):
return results return results
ret_val = yield self.db.runInteraction( ret_val = await self.db.runInteraction(
"get_largest_public_rooms", _get_largest_public_rooms_txn "get_largest_public_rooms", _get_largest_public_rooms_txn
) )
defer.returnValue(ret_val) return ret_val
@cached(max_entries=10000) @cached(max_entries=10000)
def is_room_blocked(self, room_id): def is_room_blocked(self, room_id):
@ -509,8 +506,8 @@ class RoomWorkerStore(SQLBaseStore):
"get_rooms_paginate", _get_rooms_paginate_txn, "get_rooms_paginate", _get_rooms_paginate_txn,
) )
@cachedInlineCallbacks(max_entries=10000) @cached(max_entries=10000)
def get_ratelimit_for_user(self, user_id): async def get_ratelimit_for_user(self, user_id):
"""Check if there are any overrides for ratelimiting for the given """Check if there are any overrides for ratelimiting for the given
user user
@ -522,7 +519,7 @@ class RoomWorkerStore(SQLBaseStore):
of RatelimitOverride are None or 0 then ratelimitng has been of RatelimitOverride are None or 0 then ratelimitng has been
disabled for that user entirely. disabled for that user entirely.
""" """
row = yield self.db.simple_select_one( row = await self.db.simple_select_one(
table="ratelimit_override", table="ratelimit_override",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"), retcols=("messages_per_second", "burst_count"),
@ -538,8 +535,8 @@ class RoomWorkerStore(SQLBaseStore):
else: else:
return None return None
@cachedInlineCallbacks() @cached()
def get_retention_policy_for_room(self, room_id): async def get_retention_policy_for_room(self, room_id):
"""Get the retention policy for a given room. """Get the retention policy for a given room.
If no retention policy has been found for this room, returns a policy defined If no retention policy has been found for this room, returns a policy defined
@ -566,19 +563,17 @@ class RoomWorkerStore(SQLBaseStore):
return self.db.cursor_to_dict(txn) return self.db.cursor_to_dict(txn)
ret = yield self.db.runInteraction( ret = await self.db.runInteraction(
"get_retention_policy_for_room", get_retention_policy_for_room_txn, "get_retention_policy_for_room", get_retention_policy_for_room_txn,
) )
# If we don't know this room ID, ret will be None, in this case return the default # If we don't know this room ID, ret will be None, in this case return the default
# policy. # policy.
if not ret: if not ret:
defer.returnValue( return {
{ "min_lifetime": self.config.retention_default_min_lifetime,
"min_lifetime": self.config.retention_default_min_lifetime, "max_lifetime": self.config.retention_default_max_lifetime,
"max_lifetime": self.config.retention_default_max_lifetime, }
}
)
row = ret[0] row = ret[0]
@ -592,7 +587,7 @@ class RoomWorkerStore(SQLBaseStore):
if row["max_lifetime"] is None: if row["max_lifetime"] is None:
row["max_lifetime"] = self.config.retention_default_max_lifetime row["max_lifetime"] = self.config.retention_default_max_lifetime
defer.returnValue(row) return row
def get_media_mxcs_in_room(self, room_id): def get_media_mxcs_in_room(self, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room """Retrieves all the local and remote media MXC URIs in a given room
@ -881,8 +876,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
self._background_add_rooms_room_version_column, self._background_add_rooms_room_version_column,
) )
@defer.inlineCallbacks async def _background_insert_retention(self, progress, batch_size):
def _background_insert_retention(self, progress, batch_size):
"""Retrieves a list of all rooms within a range and inserts an entry for each of """Retrieves a list of all rooms within a range and inserts an entry for each of
them into the room_retention table. them into the room_retention table.
NULLs the property's columns if missing from the retention event in the room's NULLs the property's columns if missing from the retention event in the room's
@ -940,14 +934,14 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
else: else:
return False return False
end = yield self.db.runInteraction( end = await self.db.runInteraction(
"insert_room_retention", _background_insert_retention_txn, "insert_room_retention", _background_insert_retention_txn,
) )
if end: if end:
yield self.db.updates._end_background_update("insert_room_retention") await self.db.updates._end_background_update("insert_room_retention")
defer.returnValue(batch_size) return batch_size
async def _background_add_rooms_room_version_column( async def _background_add_rooms_room_version_column(
self, progress: dict, batch_size: int self, progress: dict, batch_size: int
@ -1096,8 +1090,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
lock=False, lock=False,
) )
@defer.inlineCallbacks async def store_room(
def store_room(
self, self,
room_id: str, room_id: str,
room_creator_user_id: str, room_creator_user_id: str,
@ -1140,7 +1133,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
) )
with self._public_room_id_gen.get_next() as next_id: with self._public_room_id_gen.get_next() as next_id:
yield self.db.runInteraction("store_room_txn", store_room_txn, next_id) await self.db.runInteraction("store_room_txn", store_room_txn, next_id)
except Exception as e: except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e) logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.") raise StoreError(500, "Problem creating room.")
@ -1165,8 +1158,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
lock=False, lock=False,
) )
@defer.inlineCallbacks async def set_room_is_public(self, room_id, is_public):
def set_room_is_public(self, room_id, is_public):
def set_room_is_public_txn(txn, next_id): def set_room_is_public_txn(txn, next_id):
self.db.simple_update_one_txn( self.db.simple_update_one_txn(
txn, txn,
@ -1206,13 +1198,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
) )
with self._public_room_id_gen.get_next() as next_id: with self._public_room_id_gen.get_next() as next_id:
yield self.db.runInteraction( await self.db.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id "set_room_is_public", set_room_is_public_txn, next_id
) )
self.hs.get_notifier().on_new_replication_data() self.hs.get_notifier().on_new_replication_data()
@defer.inlineCallbacks async def set_room_is_public_appservice(
def set_room_is_public_appservice(
self, room_id, appservice_id, network_id, is_public self, room_id, appservice_id, network_id, is_public
): ):
"""Edit the appservice/network specific public room list. """Edit the appservice/network specific public room list.
@ -1287,7 +1278,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
) )
with self._public_room_id_gen.get_next() as next_id: with self._public_room_id_gen.get_next() as next_id:
yield self.db.runInteraction( await self.db.runInteraction(
"set_room_is_public_appservice", "set_room_is_public_appservice",
set_room_is_public_appservice_txn, set_room_is_public_appservice_txn,
next_id, next_id,
@ -1327,52 +1318,47 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def get_current_public_room_stream_id(self): def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token() return self._public_room_id_gen.get_current_token()
@defer.inlineCallbacks async def block_room(self, room_id: str, user_id: str) -> None:
def block_room(self, room_id, user_id):
"""Marks the room as blocked. Can be called multiple times. """Marks the room as blocked. Can be called multiple times.
Args: Args:
room_id (str): Room to block room_id: Room to block
user_id (str): Who blocked it user_id: Who blocked it
Returns:
Deferred
""" """
yield self.db.simple_upsert( await self.db.simple_upsert(
table="blocked_rooms", table="blocked_rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
values={}, values={},
insertion_values={"user_id": user_id}, insertion_values={"user_id": user_id},
desc="block_room", desc="block_room",
) )
yield self.db.runInteraction( await self.db.runInteraction(
"block_room_invalidation", "block_room_invalidation",
self._invalidate_cache_and_stream, self._invalidate_cache_and_stream,
self.is_room_blocked, self.is_room_blocked,
(room_id,), (room_id,),
) )
@defer.inlineCallbacks async def get_rooms_for_retention_period_in_range(
def get_rooms_for_retention_period_in_range( self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
self, min_ms, max_ms, include_null=False ) -> Dict[str, dict]:
):
"""Retrieves all of the rooms within the given retention range. """Retrieves all of the rooms within the given retention range.
Optionally includes the rooms which don't have a retention policy. Optionally includes the rooms which don't have a retention policy.
Args: Args:
min_ms (int|None): Duration in milliseconds that define the lower limit of min_ms: Duration in milliseconds that define the lower limit of
the range to handle (exclusive). If None, doesn't set a lower limit. the range to handle (exclusive). If None, doesn't set a lower limit.
max_ms (int|None): Duration in milliseconds that define the upper limit of max_ms: Duration in milliseconds that define the upper limit of
the range to handle (inclusive). If None, doesn't set an upper limit. the range to handle (inclusive). If None, doesn't set an upper limit.
include_null (bool): Whether to include rooms which retention policy is NULL include_null: Whether to include rooms which retention policy is NULL
in the returned set. in the returned set.
Returns: Returns:
dict[str, dict]: The rooms within this range, along with their retention The rooms within this range, along with their retention
policy. The key is "room_id", and maps to a dict describing the retention policy. The key is "room_id", and maps to a dict describing the retention
policy associated with this room ID. The keys for this nested dict are policy associated with this room ID. The keys for this nested dict are
"min_lifetime" (int|None), and "max_lifetime" (int|None). "min_lifetime" (int|None), and "max_lifetime" (int|None).
""" """
def get_rooms_for_retention_period_in_range_txn(txn): def get_rooms_for_retention_period_in_range_txn(txn):
@ -1431,9 +1417,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return rooms_dict return rooms_dict
rooms = yield self.db.runInteraction( rooms = await self.db.runInteraction(
"get_rooms_for_retention_period_in_range", "get_rooms_for_retention_period_in_range",
get_rooms_for_retention_period_in_range_txn, get_rooms_for_retention_period_in_range_txn,
) )
defer.returnValue(rooms) return rooms

View file

@ -0,0 +1,18 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Store a boolean value in the events table for whether the event should be counted in
-- the unread_count property of sync responses.
ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN;

View file

@ -16,12 +16,12 @@
import collections.abc import collections.abc
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Iterable, Optional, Set
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
@ -108,28 +108,27 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
create_event = await self.get_create_event_for_room(room_id) create_event = await self.get_create_event_for_room(room_id)
return create_event.content.get("room_version", "1") return create_event.content.get("room_version", "1")
@defer.inlineCallbacks async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
def get_room_predecessor(self, room_id):
"""Get the predecessor of an upgraded room if it exists. """Get the predecessor of an upgraded room if it exists.
Otherwise return None. Otherwise return None.
Args: Args:
room_id (str) room_id: The room ID.
Returns: Returns:
Deferred[dict|None]: A dictionary containing the structure of the predecessor A dictionary containing the structure of the predecessor
field from the room's create event. The structure is subject to other servers, field from the room's create event. The structure is subject to other servers,
but it is expected to be: but it is expected to be:
* room_id (str): The room ID of the predecessor room * room_id (str): The room ID of the predecessor room
* event_id (str): The ID of the tombstone event in the predecessor room * event_id (str): The ID of the tombstone event in the predecessor room
None if a predecessor key is not found, or is not a dictionary. None if a predecessor key is not found, or is not a dictionary.
Raises: Raises:
NotFoundError if the given room is unknown NotFoundError if the given room is unknown
""" """
# Retrieve the room's create event # Retrieve the room's create event
create_event = yield self.get_create_event_for_room(room_id) create_event = await self.get_create_event_for_room(room_id)
# Retrieve the predecessor key of the create event # Retrieve the predecessor key of the create event
predecessor = create_event.content.get("predecessor", None) predecessor = create_event.content.get("predecessor", None)
@ -140,20 +139,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return predecessor return predecessor
@defer.inlineCallbacks async def get_create_event_for_room(self, room_id: str) -> EventBase:
def get_create_event_for_room(self, room_id):
"""Get the create state event for a room. """Get the create state event for a room.
Args: Args:
room_id (str) room_id: The room ID.
Returns: Returns:
Deferred[EventBase]: The room creation event. The room creation event.
Raises: Raises:
NotFoundError if the room is unknown NotFoundError if the room is unknown
""" """
state_ids = yield self.get_current_state_ids(room_id) state_ids = await self.get_current_state_ids(room_id)
create_id = state_ids.get((EventTypes.Create, "")) create_id = state_ids.get((EventTypes.Create, ""))
# If we can't find the create event, assume we've hit a dead end # If we can't find the create event, assume we've hit a dead end
@ -161,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
raise NotFoundError("Unknown room %s" % (room_id,)) raise NotFoundError("Unknown room %s" % (room_id,))
# Retrieve the room's create event and return # Retrieve the room's create event and return
create_event = yield self.get_event(create_id) create_event = await self.get_event(create_id)
return create_event return create_event
@cached(max_entries=100000, iterable=True) @cached(max_entries=100000, iterable=True)
@ -237,18 +235,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
) )
@defer.inlineCallbacks async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
def get_canonical_alias_for_room(self, room_id):
"""Get canonical alias for room, if any """Get canonical alias for room, if any
Args: Args:
room_id (str) room_id: The room ID
Returns: Returns:
Deferred[str|None]: The canonical alias, if any The canonical alias, if any
""" """
state = yield self.get_filtered_current_state_ids( state = await self.get_filtered_current_state_ids(
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
) )
@ -256,7 +253,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not event_id: if not event_id:
return return
event = yield self.get_event(event_id, allow_none=True) event = await self.get_event(event_id, allow_none=True)
if not event: if not event:
return return
@ -292,19 +289,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {row["event_id"]: row["state_group"] for row in rows} return {row["event_id"]: row["state_group"] for row in rows}
@defer.inlineCallbacks async def get_referenced_state_groups(
def get_referenced_state_groups(self, state_groups): self, state_groups: Iterable[int]
) -> Set[int]:
"""Check if the state groups are referenced by events. """Check if the state groups are referenced by events.
Args: Args:
state_groups (Iterable[int]) state_groups
Returns: Returns:
Deferred[set[int]]: The subset of state groups that are The subset of state groups that are referenced.
referenced.
""" """
rows = yield self.db.simple_select_many_batch( rows = await self.db.simple_select_many_batch(
table="event_to_state_groups", table="event_to_state_groups",
column="state_group", column="state_group",
iterable=state_groups, iterable=state_groups,

View file

@ -16,8 +16,8 @@
import logging import logging
from itertools import chain from itertools import chain
from typing import Tuple
from twisted.internet import defer
from twisted.internet.defer import DeferredLock from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
@ -97,13 +97,12 @@ class StatsStore(StateDeltasStore):
""" """
return (ts // self.stats_bucket_size) * self.stats_bucket_size return (ts // self.stats_bucket_size) * self.stats_bucket_size
@defer.inlineCallbacks async def _populate_stats_process_users(self, progress, batch_size):
def _populate_stats_process_users(self, progress, batch_size):
""" """
This is a background update which regenerates statistics for users. This is a background update which regenerates statistics for users.
""" """
if not self.stats_enabled: if not self.stats_enabled:
yield self.db.updates._end_background_update("populate_stats_process_users") await self.db.updates._end_background_update("populate_stats_process_users")
return 1 return 1
last_user_id = progress.get("last_user_id", "") last_user_id = progress.get("last_user_id", "")
@ -118,20 +117,20 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, (last_user_id, batch_size)) txn.execute(sql, (last_user_id, batch_size))
return [r for r, in txn] return [r for r, in txn]
users_to_work_on = yield self.db.runInteraction( users_to_work_on = await self.db.runInteraction(
"_populate_stats_process_users", _get_next_batch "_populate_stats_process_users", _get_next_batch
) )
# No more rooms -- complete the transaction. # No more rooms -- complete the transaction.
if not users_to_work_on: if not users_to_work_on:
yield self.db.updates._end_background_update("populate_stats_process_users") await self.db.updates._end_background_update("populate_stats_process_users")
return 1 return 1
for user_id in users_to_work_on: for user_id in users_to_work_on:
yield self._calculate_and_set_initial_state_for_user(user_id) await self._calculate_and_set_initial_state_for_user(user_id)
progress["last_user_id"] = user_id progress["last_user_id"] = user_id
yield self.db.runInteraction( await self.db.runInteraction(
"populate_stats_process_users", "populate_stats_process_users",
self.db.updates._background_update_progress_txn, self.db.updates._background_update_progress_txn,
"populate_stats_process_users", "populate_stats_process_users",
@ -140,13 +139,12 @@ class StatsStore(StateDeltasStore):
return len(users_to_work_on) return len(users_to_work_on)
@defer.inlineCallbacks async def _populate_stats_process_rooms(self, progress, batch_size):
def _populate_stats_process_rooms(self, progress, batch_size):
""" """
This is a background update which regenerates statistics for rooms. This is a background update which regenerates statistics for rooms.
""" """
if not self.stats_enabled: if not self.stats_enabled:
yield self.db.updates._end_background_update("populate_stats_process_rooms") await self.db.updates._end_background_update("populate_stats_process_rooms")
return 1 return 1
last_room_id = progress.get("last_room_id", "") last_room_id = progress.get("last_room_id", "")
@ -161,20 +159,20 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, (last_room_id, batch_size)) txn.execute(sql, (last_room_id, batch_size))
return [r for r, in txn] return [r for r, in txn]
rooms_to_work_on = yield self.db.runInteraction( rooms_to_work_on = await self.db.runInteraction(
"populate_stats_rooms_get_batch", _get_next_batch "populate_stats_rooms_get_batch", _get_next_batch
) )
# No more rooms -- complete the transaction. # No more rooms -- complete the transaction.
if not rooms_to_work_on: if not rooms_to_work_on:
yield self.db.updates._end_background_update("populate_stats_process_rooms") await self.db.updates._end_background_update("populate_stats_process_rooms")
return 1 return 1
for room_id in rooms_to_work_on: for room_id in rooms_to_work_on:
yield self._calculate_and_set_initial_state_for_room(room_id) await self._calculate_and_set_initial_state_for_room(room_id)
progress["last_room_id"] = room_id progress["last_room_id"] = room_id
yield self.db.runInteraction( await self.db.runInteraction(
"_populate_stats_process_rooms", "_populate_stats_process_rooms",
self.db.updates._background_update_progress_txn, self.db.updates._background_update_progress_txn,
"populate_stats_process_rooms", "populate_stats_process_rooms",
@ -696,16 +694,16 @@ class StatsStore(StateDeltasStore):
return room_deltas, user_deltas return room_deltas, user_deltas
@defer.inlineCallbacks async def _calculate_and_set_initial_state_for_room(
def _calculate_and_set_initial_state_for_room(self, room_id): self, room_id: str
) -> Tuple[dict, dict, int]:
"""Calculate and insert an entry into room_stats_current. """Calculate and insert an entry into room_stats_current.
Args: Args:
room_id (str) room_id: The room ID under calculation.
Returns: Returns:
Deferred[tuple[dict, dict, int]]: A tuple of room state, membership A tuple of room state, membership counts and stream position.
counts and stream position.
""" """
def _fetch_current_state_stats(txn): def _fetch_current_state_stats(txn):
@ -767,11 +765,11 @@ class StatsStore(StateDeltasStore):
current_state_events_count, current_state_events_count,
users_in_room, users_in_room,
pos, pos,
) = yield self.db.runInteraction( ) = await self.db.runInteraction(
"get_initial_state_for_room", _fetch_current_state_stats "get_initial_state_for_room", _fetch_current_state_stats
) )
state_event_map = yield self.get_events(event_ids, get_prev_content=False) state_event_map = await self.get_events(event_ids, get_prev_content=False)
room_state = { room_state = {
"join_rules": None, "join_rules": None,
@ -806,11 +804,11 @@ class StatsStore(StateDeltasStore):
event.content.get("m.federate", True) is True event.content.get("m.federate", True) is True
) )
yield self.update_room_state(room_id, room_state) await self.update_room_state(room_id, room_state)
local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)] local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)]
yield self.update_stats_delta( await self.update_stats_delta(
ts=self.clock.time_msec(), ts=self.clock.time_msec(),
stats_type="room", stats_type="room",
stats_id=room_id, stats_id=room_id,
@ -826,8 +824,7 @@ class StatsStore(StateDeltasStore):
}, },
) )
@defer.inlineCallbacks async def _calculate_and_set_initial_state_for_user(self, user_id):
def _calculate_and_set_initial_state_for_user(self, user_id):
def _calculate_and_set_initial_state_for_user_txn(txn): def _calculate_and_set_initial_state_for_user_txn(txn):
pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
@ -842,12 +839,12 @@ class StatsStore(StateDeltasStore):
(count,) = txn.fetchone() (count,) = txn.fetchone()
return count, pos return count, pos
joined_rooms, pos = yield self.db.runInteraction( joined_rooms, pos = await self.db.runInteraction(
"calculate_and_set_initial_state_for_user", "calculate_and_set_initial_state_for_user",
_calculate_and_set_initial_state_for_user_txn, _calculate_and_set_initial_state_for_user_txn,
) )
yield self.update_stats_delta( await self.update_stats_delta(
ts=self.clock.time_msec(), ts=self.clock.time_msec(),
stats_type="user", stats_type="user",
stats_id=user_id, stats_id=user_id,

View file

@ -139,10 +139,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"get_state_group_delta", _get_state_group_delta_txn "get_state_group_delta", _get_state_group_delta_txn
) )
@defer.inlineCallbacks async def _get_state_groups_from_groups(
def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter self, groups: List[int], state_filter: StateFilter
): ) -> Dict[int, StateMap[str]]:
"""Returns the state groups for a given set of groups from the """Returns the state groups for a given set of groups from the
database, filtering on types of state events. database, filtering on types of state events.
@ -151,13 +150,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_filter: The state filter used to fetch state state_filter: The state filter used to fetch state
from the database. from the database.
Returns: Returns:
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. Dict of state group to state map.
""" """
results = {} results = {}
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks: for chunk in chunks:
res = yield self.db.runInteraction( res = await self.db.runInteraction(
"_get_state_groups_from_groups", "_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn, self._get_state_groups_from_groups_txn,
chunk, chunk,
@ -206,10 +205,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types return state_filter.filter_state(state_dict_ids), not missing_types
@defer.inlineCallbacks async def _get_state_for_groups(
def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
): ) -> Dict[int, StateMap[str]]:
"""Gets the state at each of a list of state groups, optionally """Gets the state at each of a list of state groups, optionally
filtering by type/state_key filtering by type/state_key
@ -219,7 +217,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_filter: The state filter used to fetch state state_filter: The state filter used to fetch state
from the database. from the database.
Returns: Returns:
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. Dict of state group to state map.
""" """
member_filter, non_member_filter = state_filter.get_member_split() member_filter, non_member_filter = state_filter.get_member_split()
@ -228,14 +226,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
( (
non_member_state, non_member_state,
incomplete_groups_nm, incomplete_groups_nm,
) = yield self._get_state_for_groups_using_cache( ) = self._get_state_for_groups_using_cache(
groups, self._state_group_cache, state_filter=non_member_filter groups, self._state_group_cache, state_filter=non_member_filter
) )
( (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache(
member_state,
incomplete_groups_m,
) = yield self._get_state_for_groups_using_cache(
groups, self._state_group_members_cache, state_filter=member_filter groups, self._state_group_members_cache, state_filter=member_filter
) )
@ -256,7 +251,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# Help the cache hit ratio by expanding the filter a bit # Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded() db_state_filter = state_filter.return_expanded()
group_to_state_dict = yield self._get_state_groups_from_groups( group_to_state_dict = await self._get_state_groups_from_groups(
list(incomplete_groups), state_filter=db_state_filter list(incomplete_groups), state_filter=db_state_filter
) )
@ -576,19 +571,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
((sg,) for sg in state_groups_to_delete), ((sg,) for sg in state_groups_to_delete),
) )
@defer.inlineCallbacks async def get_previous_state_groups(
def get_previous_state_groups(self, state_groups): self, state_groups: Iterable[int]
) -> Dict[int, int]:
"""Fetch the previous groups of the given state groups. """Fetch the previous groups of the given state groups.
Args: Args:
state_groups (Iterable[int]) state_groups
Returns: Returns:
Deferred[dict[int, int]]: mapping from state group to previous A mapping from state group to previous state group.
state group.
""" """
rows = yield self.db.simple_select_many_batch( rows = await self.db.simple_select_many_batch(
table="state_group_edges", table="state_group_edges",
column="prev_state_group", column="prev_state_group",
iterable=state_groups, iterable=state_groups,

View file

@ -49,11 +49,11 @@ from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3E
from synapse.storage.types import Connection, Cursor from synapse.storage.types import Connection, Cursor
from synapse.types import Collection from synapse.types import Collection
logger = logging.getLogger(__name__)
# python 3 does not have a maximum int value # python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1 MAX_TXN_ID = 2 ** 63 - 1
logger = logging.getLogger(__name__)
sql_logger = logging.getLogger("synapse.storage.SQL") sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn") transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME") perf_logger = logging.getLogger("synapse.storage.TIME")
@ -233,7 +233,7 @@ class LoggingTransaction:
try: try:
return func(sql, *args) return func(sql, *args)
except Exception as e: except Exception as e:
logger.debug("[SQL FAIL] {%s} %s", self.name, e) sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise raise
finally: finally:
secs = time.time() - start secs = time.time() - start
@ -419,7 +419,7 @@ class Database(object):
except self.engine.module.OperationalError as e: except self.engine.module.OperationalError as e:
# This can happen if the database disappears mid # This can happen if the database disappears mid
# transaction. # transaction.
logger.warning( transaction_logger.warning(
"[TXN OPERROR] {%s} %s %d/%d", name, e, i, N, "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
) )
if i < N: if i < N:
@ -427,18 +427,20 @@ class Database(object):
try: try:
conn.rollback() conn.rollback()
except self.engine.module.Error as e1: except self.engine.module.Error as e1:
logger.warning("[TXN EROLL] {%s} %s", name, e1) transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1)
continue continue
raise raise
except self.engine.module.DatabaseError as e: except self.engine.module.DatabaseError as e:
if self.engine.is_deadlock(e): if self.engine.is_deadlock(e):
logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) transaction_logger.warning(
"[TXN DEADLOCK] {%s} %d/%d", name, i, N
)
if i < N: if i < N:
i += 1 i += 1
try: try:
conn.rollback() conn.rollback()
except self.engine.module.Error as e1: except self.engine.module.Error as e1:
logger.warning( transaction_logger.warning(
"[TXN EROLL] {%s} %s", name, e1, "[TXN EROLL] {%s} %s", name, e1,
) )
continue continue
@ -478,7 +480,7 @@ class Database(object):
# [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
cursor.close() cursor.close()
except Exception as e: except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e) transaction_logger.debug("[TXN FAIL] {%s} %s", name, e)
raise raise
finally: finally:
end = monotonic_time() end = monotonic_time()

View file

@ -25,7 +25,7 @@ from prometheus_client import Counter, Histogram
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events import FrozenEvent from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
@ -192,12 +192,11 @@ class EventsPersistenceStorage(object):
self._event_persist_queue = _EventPeristenceQueue() self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler() self._state_resolution_handler = hs.get_state_resolution_handler()
@defer.inlineCallbacks async def persist_events(
def persist_events(
self, self,
events_and_contexts: List[Tuple[FrozenEvent, EventContext]], events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False, backfilled: bool = False,
): ) -> int:
""" """
Write events to the database Write events to the database
Args: Args:
@ -207,7 +206,7 @@ class EventsPersistenceStorage(object):
which might update the current state etc. which might update the current state etc.
Returns: Returns:
Deferred[int]: the stream ordering of the latest persisted event the stream ordering of the latest persisted event
""" """
partitioned = {} partitioned = {}
for event, ctx in events_and_contexts: for event, ctx in events_and_contexts:
@ -223,22 +222,19 @@ class EventsPersistenceStorage(object):
for room_id in partitioned: for room_id in partitioned:
self._maybe_start_persisting(room_id) self._maybe_start_persisting(room_id)
yield make_deferred_yieldable( await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True) defer.gatherResults(deferreds, consumeErrors=True)
) )
max_persisted_id = yield self.main_store.get_current_events_token() return self.main_store.get_current_events_token()
return max_persisted_id async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
@defer.inlineCallbacks ) -> Tuple[int, int]:
def persist_event(
self, event: FrozenEvent, context: EventContext, backfilled: bool = False
):
""" """
Returns: Returns:
Deferred[Tuple[int, int]]: the stream ordering of ``event``, The stream ordering of `event`, and the stream ordering of the
and the stream ordering of the latest persisted event latest persisted event
""" """
deferred = self._event_persist_queue.add_to_queue( deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], backfilled=backfilled event.room_id, [(event, context)], backfilled=backfilled
@ -246,9 +242,9 @@ class EventsPersistenceStorage(object):
self._maybe_start_persisting(event.room_id) self._maybe_start_persisting(event.room_id)
yield make_deferred_yieldable(deferred) await make_deferred_yieldable(deferred)
max_persisted_id = yield self.main_store.get_current_events_token() max_persisted_id = self.main_store.get_current_events_token()
return (event.internal_metadata.stream_ordering, max_persisted_id) return (event.internal_metadata.stream_ordering, max_persisted_id)
def _maybe_start_persisting(self, room_id: str): def _maybe_start_persisting(self, room_id: str):
@ -262,7 +258,7 @@ class EventsPersistenceStorage(object):
async def _persist_events( async def _persist_events(
self, self,
events_and_contexts: List[Tuple[FrozenEvent, EventContext]], events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False, backfilled: bool = False,
): ):
"""Calculates the change to current state and forward extremities, and """Calculates the change to current state and forward extremities, and
@ -439,7 +435,7 @@ class EventsPersistenceStorage(object):
async def _calculate_new_extremities( async def _calculate_new_extremities(
self, self,
room_id: str, room_id: str,
event_contexts: List[Tuple[FrozenEvent, EventContext]], event_contexts: List[Tuple[EventBase, EventContext]],
latest_event_ids: List[str], latest_event_ids: List[str],
): ):
"""Calculates the new forward extremities for a room given events to """Calculates the new forward extremities for a room given events to
@ -497,7 +493,7 @@ class EventsPersistenceStorage(object):
async def _get_new_state_after_events( async def _get_new_state_after_events(
self, self,
room_id: str, room_id: str,
events_context: List[Tuple[FrozenEvent, EventContext]], events_context: List[Tuple[EventBase, EventContext]],
old_latest_event_ids: Iterable[str], old_latest_event_ids: Iterable[str],
new_latest_event_ids: Iterable[str], new_latest_event_ids: Iterable[str],
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]: ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
@ -683,7 +679,7 @@ class EventsPersistenceStorage(object):
async def _is_server_still_joined( async def _is_server_still_joined(
self, self,
room_id: str, room_id: str,
ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]], ev_ctx_rm: List[Tuple[EventBase, EventContext]],
delta: DeltaState, delta: DeltaState,
current_state: Optional[StateMap[str]], current_state: Optional[StateMap[str]],
potentially_left_users: Set[str], potentially_left_users: Set[str],

View file

@ -15,8 +15,7 @@
import itertools import itertools
import logging import logging
from typing import Set
from twisted.internet import defer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,49 +27,48 @@ class PurgeEventsStorage(object):
def __init__(self, hs, stores): def __init__(self, hs, stores):
self.stores = stores self.stores = stores
@defer.inlineCallbacks async def purge_room(self, room_id: str):
def purge_room(self, room_id: str):
"""Deletes all record of a room """Deletes all record of a room
""" """
state_groups_to_delete = yield self.stores.main.purge_room(room_id) state_groups_to_delete = await self.stores.main.purge_room(room_id)
yield self.stores.state.purge_room_state(room_id, state_groups_to_delete) await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
@defer.inlineCallbacks async def purge_history(
def purge_history(self, room_id, token, delete_local_events): self, room_id: str, token: str, delete_local_events: bool
) -> None:
"""Deletes room history before a certain point """Deletes room history before a certain point
Args: Args:
room_id (str): room_id: The room ID
token (str): A topological token to delete events before token: A topological token to delete events before
delete_local_events (bool): delete_local_events:
if True, we will delete local events as well as remote ones if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their (instead of just marking them as outliers and deleting their
state groups). state groups).
""" """
state_groups = yield self.stores.main.purge_history( state_groups = await self.stores.main.purge_history(
room_id, token, delete_local_events room_id, token, delete_local_events
) )
logger.info("[purge] finding state groups that can be deleted") logger.info("[purge] finding state groups that can be deleted")
sg_to_delete = yield self._find_unreferenced_groups(state_groups) sg_to_delete = await self._find_unreferenced_groups(state_groups)
yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
@defer.inlineCallbacks async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]:
def _find_unreferenced_groups(self, state_groups):
"""Used when purging history to figure out which state groups can be """Used when purging history to figure out which state groups can be
deleted. deleted.
Args: Args:
state_groups (set[int]): Set of state groups referenced by events state_groups: Set of state groups referenced by events
that are going to be deleted. that are going to be deleted.
Returns: Returns:
Deferred[set[int]] The set of state groups that can be deleted. The set of state groups that can be deleted.
""" """
# Graph of state group -> previous group # Graph of state group -> previous group
graph = {} graph = {}
@ -93,7 +91,7 @@ class PurgeEventsStorage(object):
current_search = set(itertools.islice(next_to_search, 100)) current_search = set(itertools.islice(next_to_search, 100))
next_to_search -= current_search next_to_search -= current_search
referenced = yield self.stores.main.get_referenced_state_groups( referenced = await self.stores.main.get_referenced_state_groups(
current_search current_search
) )
referenced_groups |= referenced referenced_groups |= referenced
@ -102,7 +100,7 @@ class PurgeEventsStorage(object):
# groups that are referenced. # groups that are referenced.
current_search -= referenced current_search -= referenced
edges = yield self.stores.state.get_previous_state_groups(current_search) edges = await self.stores.state.get_previous_state_groups(current_search)
prevs = set(edges.values()) prevs = set(edges.values())
# We don't bother re-handling groups we've already seen # We don't bother re-handling groups we've already seen

View file

@ -14,13 +14,12 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Iterable, List, TypeVar from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
import attr import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import StateMap from synapse.types import StateMap
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,16 +33,16 @@ class StateFilter(object):
"""A filter used when querying for state. """A filter used when querying for state.
Attributes: Attributes:
types (dict[str, set[str]|None]): Map from type to set of state keys (or types: Map from type to set of state keys (or None). This specifies
None). This specifies which state_keys for the given type to fetch which state_keys for the given type to fetch from the DB. If None
from the DB. If None then all events with that type are fetched. If then all events with that type are fetched. If the set is empty
the set is empty then no events with that type are fetched. then no events with that type are fetched.
include_others (bool): Whether to fetch events with types that do not include_others: Whether to fetch events with types that do not
appear in `types`. appear in `types`.
""" """
types = attr.ib() types = attr.ib(type=Dict[str, Optional[Set[str]]])
include_others = attr.ib(default=False) include_others = attr.ib(default=False, type=bool)
def __attrs_post_init__(self): def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing # If `include_others` is set we canonicalise the filter by removing
@ -52,36 +51,35 @@ class StateFilter(object):
self.types = {k: v for k, v in self.types.items() if v is not None} self.types = {k: v for k, v in self.types.items() if v is not None}
@staticmethod @staticmethod
def all(): def all() -> "StateFilter":
"""Creates a filter that fetches everything. """Creates a filter that fetches everything.
Returns: Returns:
StateFilter The new state filter.
""" """
return StateFilter(types={}, include_others=True) return StateFilter(types={}, include_others=True)
@staticmethod @staticmethod
def none(): def none() -> "StateFilter":
"""Creates a filter that fetches nothing. """Creates a filter that fetches nothing.
Returns: Returns:
StateFilter The new state filter.
""" """
return StateFilter(types={}, include_others=False) return StateFilter(types={}, include_others=False)
@staticmethod @staticmethod
def from_types(types): def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
"""Creates a filter that only fetches the given types """Creates a filter that only fetches the given types
Args: Args:
types (Iterable[tuple[str, str|None]]): A list of type and state types: A list of type and state keys to fetch. A state_key of None
keys to fetch. A state_key of None fetches everything for fetches everything for that type
that type
Returns: Returns:
StateFilter The new state filter.
""" """
type_dict = {} type_dict = {} # type: Dict[str, Optional[Set[str]]]
for typ, s in types: for typ, s in types:
if typ in type_dict: if typ in type_dict:
if type_dict[typ] is None: if type_dict[typ] is None:
@ -91,24 +89,24 @@ class StateFilter(object):
type_dict[typ] = None type_dict[typ] = None
continue continue
type_dict.setdefault(typ, set()).add(s) type_dict.setdefault(typ, set()).add(s) # type: ignore
return StateFilter(types=type_dict) return StateFilter(types=type_dict)
@staticmethod @staticmethod
def from_lazy_load_member_list(members): def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
"""Creates a filter that returns all non-member events, plus the member """Creates a filter that returns all non-member events, plus the member
events for the given users events for the given users
Args: Args:
members (iterable[str]): Set of user IDs members: Set of user IDs
Returns: Returns:
StateFilter The new state filter
""" """
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
def return_expanded(self): def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed """Creates a new StateFilter where type wild cards have been removed
(except for memberships). The returned filter is a superset of the (except for memberships). The returned filter is a superset of the
current one, i.e. anything that passes the current filter will pass current one, i.e. anything that passes the current filter will pass
@ -130,7 +128,7 @@ class StateFilter(object):
return all non-member events return all non-member events
Returns: Returns:
StateFilter The new state filter.
""" """
if self.is_full(): if self.is_full():
@ -167,7 +165,7 @@ class StateFilter(object):
include_others=True, include_others=True,
) )
def make_sql_filter_clause(self): def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
"""Converts the filter to an SQL clause. """Converts the filter to an SQL clause.
For example: For example:
@ -179,13 +177,12 @@ class StateFilter(object):
Returns: Returns:
tuple[str, list]: The SQL string (may be empty) and arguments. An The SQL string (may be empty) and arguments. An empty SQL string is
empty SQL string is returned when the filter matches everything returned when the filter matches everything (i.e. is "full").
(i.e. is "full").
""" """
where_clause = "" where_clause = ""
where_args = [] where_args = [] # type: List[str]
if self.is_full(): if self.is_full():
return where_clause, where_args return where_clause, where_args
@ -221,7 +218,7 @@ class StateFilter(object):
return where_clause, where_args return where_clause, where_args
def max_entries_returned(self): def max_entries_returned(self) -> Optional[int]:
"""Returns the maximum number of entries this filter will return if """Returns the maximum number of entries this filter will return if
known, otherwise returns None. known, otherwise returns None.
@ -260,33 +257,33 @@ class StateFilter(object):
return filtered_state return filtered_state
def is_full(self): def is_full(self) -> bool:
"""Whether this filter fetches everything or not """Whether this filter fetches everything or not
Returns: Returns:
bool True if the filter fetches everything.
""" """
return self.include_others and not self.types return self.include_others and not self.types
def has_wildcards(self): def has_wildcards(self) -> bool:
"""Whether the filter includes wildcards or is attempting to fetch """Whether the filter includes wildcards or is attempting to fetch
specific state. specific state.
Returns: Returns:
bool True if the filter includes wildcards.
""" """
return self.include_others or any( return self.include_others or any(
state_keys is None for state_keys in self.types.values() state_keys is None for state_keys in self.types.values()
) )
def concrete_types(self): def concrete_types(self) -> List[Tuple[str, str]]:
"""Returns a list of concrete type/state_keys (i.e. not None) that """Returns a list of concrete type/state_keys (i.e. not None) that
will be fetched. This will be a complete list if `has_wildcards` will be fetched. This will be a complete list if `has_wildcards`
returns False, but otherwise will be a subset (or even empty). returns False, but otherwise will be a subset (or even empty).
Returns: Returns:
list[tuple[str,str]] A list of type/state_keys tuples.
""" """
return [ return [
(t, s) (t, s)
@ -295,7 +292,7 @@ class StateFilter(object):
for s in state_keys for s in state_keys
] ]
def get_member_split(self): def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
"""Return the filter split into two: one which assumes it's exclusively """Return the filter split into two: one which assumes it's exclusively
matching against member state, and one which assumes it's matching matching against member state, and one which assumes it's matching
against non member state. against non member state.
@ -307,7 +304,7 @@ class StateFilter(object):
state caches). state caches).
Returns: Returns:
tuple[StateFilter, StateFilter]: The member and non member filters The member and non member filters
""" """
if EventTypes.Member in self.types: if EventTypes.Member in self.types:
@ -340,6 +337,9 @@ class StateGroupStorage(object):
"""Given a state group try to return a previous group and a delta between """Given a state group try to return a previous group and a delta between
the old and the new. the old and the new.
Args:
state_group: The state group used to retrieve state deltas.
Returns: Returns:
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
(prev_group, delta_ids) (prev_group, delta_ids)
@ -347,55 +347,59 @@ class StateGroupStorage(object):
return self.stores.state.get_state_group_delta(state_group) return self.stores.state.get_state_group_delta(state_group)
@defer.inlineCallbacks async def get_state_groups_ids(
def get_state_groups_ids(self, _room_id, event_ids): self, _room_id: str, event_ids: Iterable[str]
) -> Dict[int, StateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events """Get the event IDs of all the state for the state groups for the given events
Args: Args:
_room_id (str): id of the room for these events _room_id: id of the room for these events
event_ids (iterable[str]): ids of the events event_ids: ids of the events
Returns: Returns:
Deferred[dict[int, StateMap[str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id)
dict of state_group_id -> (dict of (type, state_key) -> event id)
""" """
if not event_ids: if not event_ids:
return {} return {}
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(groups) group_to_state = await self.stores.state._get_state_for_groups(groups)
return group_to_state return group_to_state
@defer.inlineCallbacks async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
def get_state_ids_for_group(self, state_group):
"""Get the event IDs of all the state in the given state group """Get the event IDs of all the state in the given state group
Args: Args:
state_group (int) state_group: A state group for which we want to get the state IDs.
Returns: Returns:
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id Resolves to a map of (type, state_key) -> event_id
""" """
group_to_state = yield self._get_state_for_groups((state_group,)) group_to_state = await self._get_state_for_groups((state_group,))
return group_to_state[state_group] return group_to_state[state_group]
@defer.inlineCallbacks async def get_state_groups(
def get_state_groups(self, room_id, event_ids): self, room_id: str, event_ids: Iterable[str]
) -> Dict[int, List[EventBase]]:
""" Get the state groups for the given list of event_ids """ Get the state groups for the given list of event_ids
Args:
room_id: ID of the room for these events.
event_ids: The event IDs to retrieve state for.
Returns: Returns:
Deferred[dict[int, list[EventBase]]]: dict of state_group_id -> list of state events.
dict of state_group_id -> list of state events.
""" """
if not event_ids: if not event_ids:
return {} return {}
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
state_event_map = yield self.stores.main.get_events( state_event_map = await self.stores.main.get_events(
[ [
ev_id ev_id
for group_ids in group_to_ids.values() for group_ids in group_to_ids.values()
@ -415,7 +419,7 @@ class StateGroupStorage(object):
def _get_state_groups_from_groups( def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter self, groups: List[int], state_filter: StateFilter
): ) -> Awaitable[Dict[int, StateMap[str]]]:
"""Returns the state groups for a given set of groups, filtering on """Returns the state groups for a given set of groups, filtering on
types of state events. types of state events.
@ -423,31 +427,34 @@ class StateGroupStorage(object):
groups: list of state group IDs to query groups: list of state group IDs to query
state_filter: The state filter used to fetch state state_filter: The state filter used to fetch state
from the database. from the database.
Returns: Returns:
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. Dict of state group to state map.
""" """
return self.stores.state._get_state_groups_from_groups(groups, state_filter) return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@defer.inlineCallbacks async def get_state_for_events(
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
):
"""Given a list of event_ids and type tuples, return a list of state """Given a list of event_ids and type tuples, return a list of state
dicts for each event. dicts for each event.
Args: Args:
event_ids (list[string]) event_ids: The events to fetch the state of.
state_filter (StateFilter): The state filter used to fetch state state_filter: The state filter used to fetch state.
from the database.
Returns: Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events] A dict of (event_id) -> (type, state_key) -> [state_events]
""" """
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups( group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter groups, state_filter
) )
state_event_map = yield self.stores.main.get_events( state_event_map = await self.stores.main.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()], [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False, get_prev_content=False,
) )
@ -463,24 +470,24 @@ class StateGroupStorage(object):
return {event: event_to_state[event] for event in event_ids} return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks async def get_state_ids_for_events(
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
):
""" """
Get the state dicts corresponding to a list of events, containing the event_ids Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves) of the state events (as opposed to the events themselves)
Args: Args:
event_ids(list(str)): events whose state should be returned event_ids: events whose state should be returned
state_filter (StateFilter): The state filter used to fetch state state_filter: The state filter used to fetch state from the database.
from the database.
Returns: Returns:
A deferred dict from event_id -> (type, state_key) -> event_id A dict from event_id -> (type, state_key) -> event_id
""" """
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups( group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter groups, state_filter
) )
@ -491,67 +498,72 @@ class StateGroupStorage(object):
return {event: event_to_state[event] for event in event_ids} return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks async def get_state_for_event(
def get_state_for_event(self, event_id, state_filter=StateFilter.all()): self, event_id: str, state_filter: StateFilter = StateFilter.all()
):
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event
Args: Args:
event_id(str): event whose state should be returned event_id: event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state state_filter: The state filter used to fetch state from the database.
from the database.
Returns: Returns:
A deferred dict from (type, state_key) -> state_event A dict from (type, state_key) -> state_event
""" """
state_map = yield self.get_state_for_events([event_id], state_filter) state_map = await self.get_state_for_events([event_id], state_filter)
return state_map[event_id] return state_map[event_id]
@defer.inlineCallbacks async def get_state_ids_for_event(
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): self, event_id: str, state_filter: StateFilter = StateFilter.all()
):
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event
Args: Args:
event_id(str): event whose state should be returned event_id: event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state state_filter: The state filter used to fetch state from the database.
from the database.
Returns: Returns:
A deferred dict from (type, state_key) -> state_event A deferred dict from (type, state_key) -> state_event
""" """
state_map = yield self.get_state_ids_for_events([event_id], state_filter) state_map = await self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id] return state_map[event_id]
def _get_state_for_groups( def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
): ) -> Awaitable[Dict[int, StateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally """Gets the state at each of a list of state groups, optionally
filtering by type/state_key filtering by type/state_key
Args: Args:
groups (iterable[int]): list of state groups for which we want groups: list of state groups for which we want to get the state.
to get the state. state_filter: The state filter used to fetch state.
state_filter (StateFilter): The state filter used to fetch state
from the database. from the database.
Returns: Returns:
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. Dict of state group to state map.
""" """
return self.stores.state._get_state_for_groups(groups, state_filter) return self.stores.state._get_state_for_groups(groups, state_filter)
def store_state_group( def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids self,
event_id: str,
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[dict],
current_state_ids: dict,
): ):
"""Store a new set of state, returning a newly assigned state group. """Store a new set of state, returning a newly assigned state group.
Args: Args:
event_id (str): The event ID for which the state was calculated event_id: The event ID for which the state was calculated.
room_id (str) room_id: ID of the room for which the state was calculated.
prev_group (int|None): A previous state group for the room, optional. prev_group: A previous state group for the room, optional.
delta_ids (dict|None): The delta between state at `prev_group` and delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as `current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`. `current_state_ids`.
current_state_ids (dict): The state to store. Map of (type, state_key) current_state_ids: The state to store. Map of (type, state_key)
to event_id. to event_id.
Returns: Returns:

View file

@ -16,8 +16,6 @@
import logging import logging
import operator import operator
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.storage import Storage from synapse.storage import Storage
@ -39,8 +37,7 @@ MEMBERSHIP_PRIORITY = (
) )
@defer.inlineCallbacks async def filter_events_for_client(
def filter_events_for_client(
storage: Storage, storage: Storage,
user_id, user_id,
events, events,
@ -67,19 +64,19 @@ def filter_events_for_client(
also be called to check whether a user can see the state at a given point. also be called to check whether a user can see the state at a given point.
Returns: Returns:
Deferred[list[synapse.events.EventBase]] list[synapse.events.EventBase]
""" """
# Filter out events that have been soft failed so that we don't relay them # Filter out events that have been soft failed so that we don't relay them
# to clients. # to clients.
events = [e for e in events if not e.internal_metadata.is_soft_failed()] events = [e for e in events if not e.internal_metadata.is_soft_failed()]
types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
event_id_to_state = yield storage.state.get_state_for_events( event_id_to_state = await storage.state.get_state_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(types), state_filter=StateFilter.from_types(types),
) )
ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user( ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id "m.ignored_user_list", user_id
) )
@ -90,7 +87,7 @@ def filter_events_for_client(
else [] else []
) )
erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) erased_senders = await storage.main.are_users_erased((e.sender for e in events))
if filter_send_to_client: if filter_send_to_client:
room_ids = {e.room_id for e in events} room_ids = {e.room_id for e in events}
@ -99,7 +96,7 @@ def filter_events_for_client(
for room_id in room_ids: for room_id in room_ids:
retention_policies[ retention_policies[
room_id room_id
] = yield storage.main.get_retention_policy_for_room(room_id) ] = await storage.main.get_retention_policy_for_room(room_id)
def allowed(event): def allowed(event):
""" """
@ -254,8 +251,7 @@ def filter_events_for_client(
return list(filtered_events) return list(filtered_events)
@defer.inlineCallbacks async def filter_events_for_server(
def filter_events_for_server(
storage: Storage, storage: Storage,
server_name, server_name,
events, events,
@ -277,7 +273,7 @@ def filter_events_for_server(
backfill or not. backfill or not.
Returns Returns
Deferred[list[FrozenEvent]] list[FrozenEvent]
""" """
def is_sender_erased(event, erased_senders): def is_sender_erased(event, erased_senders):
@ -321,7 +317,7 @@ def filter_events_for_server(
# Lets check to see if all the events have a history visibility # Lets check to see if all the events have a history visibility
# of "shared" or "world_readable". If that's the case then we don't # of "shared" or "world_readable". If that's the case then we don't
# need to check membership (as we know the server is in the room). # need to check membership (as we know the server is in the room).
event_to_state_ids = yield storage.state.get_state_ids_for_events( event_to_state_ids = await storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types( state_filter=StateFilter.from_types(
types=((EventTypes.RoomHistoryVisibility, ""),) types=((EventTypes.RoomHistoryVisibility, ""),)
@ -339,14 +335,14 @@ def filter_events_for_server(
if not visibility_ids: if not visibility_ids:
all_open = True all_open = True
else: else:
event_map = yield storage.main.get_events(visibility_ids) event_map = await storage.main.get_events(visibility_ids)
all_open = all( all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable") e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in event_map.values() for e in event_map.values()
) )
if not check_history_visibility_only: if not check_history_visibility_only:
erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) erased_senders = await storage.main.are_users_erased((e.sender for e in events))
else: else:
# We don't want to check whether users are erased, which is equivalent # We don't want to check whether users are erased, which is equivalent
# to no users having been erased. # to no users having been erased.
@ -375,7 +371,7 @@ def filter_events_for_server(
# first, for each event we're wanting to return, get the event_ids # first, for each event we're wanting to return, get the event_ids
# of the history vis and membership state at those events. # of the history vis and membership state at those events.
event_to_state_ids = yield storage.state.get_state_ids_for_events( event_to_state_ids = await storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types( state_filter=StateFilter.from_types(
types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None))
@ -405,7 +401,7 @@ def filter_events_for_server(
return False return False
return state_key[idx + 1 :] == server_name return state_key[idx + 1 :] == server_name
event_map = yield storage.main.get_events( event_map = await storage.main.get_events(
[e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])] [e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])]
) )

View file

@ -50,13 +50,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_regex_user_id_prefix_match(self): def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue((yield self.service.is_interested(self.event))) self.assertTrue(
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(self): def test_regex_user_id_prefix_no_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.assertFalse((yield self.service.is_interested(self.event))) self.assertFalse(
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_room_member_is_checked(self): def test_regex_room_member_is_checked(self):
@ -64,7 +68,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member" self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org" self.event.state_key = "@irc_foobar:matrix.org"
self.assertTrue((yield self.service.is_interested(self.event))) self.assertTrue(
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_room_id_match(self): def test_regex_room_id_match(self):
@ -72,7 +78,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("!some_prefix.*some_suffix:matrix.org") _regex("!some_prefix.*some_suffix:matrix.org")
) )
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
self.assertTrue((yield self.service.is_interested(self.event))) self.assertTrue(
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_room_id_no_match(self): def test_regex_room_id_no_match(self):
@ -80,19 +88,26 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("!some_prefix.*some_suffix:matrix.org") _regex("!some_prefix.*some_suffix:matrix.org")
) )
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
self.assertFalse((yield self.service.is_interested(self.event))) self.assertFalse(
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_alias_match(self): def test_regex_alias_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
self.store.get_aliases_for_room.return_value = [ self.store.get_aliases_for_room.return_value = defer.succeed(
"#irc_foobar:matrix.org", ["#irc_foobar:matrix.org", "#athing:matrix.org"]
"#athing:matrix.org", )
] self.store.get_users_in_room.return_value = defer.succeed([])
self.store.get_users_in_room.return_value = [] self.assertTrue(
self.assertTrue((yield self.service.is_interested(self.event, self.store))) (
yield defer.ensureDeferred(
self.service.is_interested(self.event, self.store)
)
)
)
def test_non_exclusive_alias(self): def test_non_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
@ -135,12 +150,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
self.store.get_aliases_for_room.return_value = [ self.store.get_aliases_for_room.return_value = defer.succeed(
"#xmpp_foobar:matrix.org", ["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
"#athing:matrix.org", )
] self.store.get_users_in_room.return_value = defer.succeed([])
self.store.get_users_in_room.return_value = [] self.assertFalse(
self.assertFalse((yield self.service.is_interested(self.event, self.store))) (
yield defer.ensureDeferred(
self.service.is_interested(self.event, self.store)
)
)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_multiple_matches(self): def test_regex_multiple_matches(self):
@ -149,9 +169,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
) )
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"] self.store.get_aliases_for_room.return_value = defer.succeed(
self.store.get_users_in_room.return_value = [] ["#irc_barfoo:matrix.org"]
self.assertTrue((yield self.service.is_interested(self.event, self.store))) )
self.store.get_users_in_room.return_value = defer.succeed([])
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested(self.event, self.store)
)
)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_interested_in_self(self): def test_interested_in_self(self):
@ -161,19 +189,24 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.type = "m.room.member" self.event.type = "m.room.member"
self.event.content = {"membership": "invite"} self.event.content = {"membership": "invite"}
self.event.state_key = self.service.sender self.event.state_key = self.service.sender
self.assertTrue((yield self.service.is_interested(self.event))) self.assertTrue(
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_member_list_match(self): def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.store.get_users_in_room.return_value = [ # Note that @irc_fo:here is the AS user.
"@alice:here", self.store.get_users_in_room.return_value = defer.succeed(
"@irc_fo:here", # AS user ["@alice:here", "@irc_fo:here", "@bob:here"]
"@bob:here", )
] self.store.get_aliases_for_room.return_value = defer.succeed([])
self.store.get_aliases_for_room.return_value = []
self.event.sender = "@xmpp_foobar:matrix.org" self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue( self.assertTrue(
(yield self.service.is_interested(event=self.event, store=self.store)) (
yield defer.ensureDeferred(
self.service.is_interested(event=self.event, store=self.store)
)
)
) )

View file

@ -25,6 +25,7 @@ from synapse.appservice.scheduler import (
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
from ..utils import MockClock from ..utils import MockClock
@ -52,11 +53,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.store.get_appservice_state = Mock( self.store.get_appservice_state = Mock(
return_value=defer.succeed(ApplicationServiceState.UP) return_value=defer.succeed(ApplicationServiceState.UP)
) )
txn.send = Mock(return_value=defer.succeed(True)) txn.send = Mock(return_value=make_awaitable(True))
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
# actual call # actual call
self.txnctrl.send(service, events) self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with( self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events # txn made and saved service=service, events=events # txn made and saved
@ -77,7 +78,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
# actual call # actual call
self.txnctrl.send(service, events) self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with( self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events # txn made and saved service=service, events=events # txn made and saved
@ -98,11 +99,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
return_value=defer.succeed(ApplicationServiceState.UP) return_value=defer.succeed(ApplicationServiceState.UP)
) )
self.store.set_appservice_state = Mock(return_value=defer.succeed(True)) self.store.set_appservice_state = Mock(return_value=defer.succeed(True))
txn.send = Mock(return_value=defer.succeed(False)) # fails to send txn.send = Mock(return_value=make_awaitable(False)) # fails to send
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
# actual call # actual call
self.txnctrl.send(service, events) self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with( self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events service=service, events=events
@ -144,7 +145,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover() self.recoverer.recover()
# shouldn't have called anything prior to waiting for exp backoff # shouldn't have called anything prior to waiting for exp backoff
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
txn.send = Mock(return_value=True) txn.send = Mock(return_value=make_awaitable(True))
txn.complete.return_value = make_awaitable(None)
# wait for exp backoff # wait for exp backoff
self.clock.advance_time(2) self.clock.advance_time(2)
self.assertEquals(1, txn.send.call_count) self.assertEquals(1, txn.send.call_count)
@ -169,7 +171,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover() self.recoverer.recover()
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
txn.send = Mock(return_value=False) txn.send = Mock(return_value=make_awaitable(False))
txn.complete.return_value = make_awaitable(None)
self.clock.advance_time(2) self.clock.advance_time(2)
self.assertEquals(1, txn.send.call_count) self.assertEquals(1, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count) self.assertEquals(0, txn.complete.call_count)
@ -182,7 +185,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.assertEquals(3, txn.send.call_count) self.assertEquals(3, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count) self.assertEquals(0, txn.complete.call_count)
self.assertEquals(0, self.callback.call_count) self.assertEquals(0, self.callback.call_count)
txn.send = Mock(return_value=True) # successfully send the txn txn.send = Mock(return_value=make_awaitable(True)) # successfully send the txn
pop_txn = True # returns the txn the first time, then no more. pop_txn = True # returns the txn the first time, then no more.
self.clock.advance_time(16) self.clock.advance_time(16)
self.assertEquals(1, txn.send.call_count) # new mock reset call count self.assertEquals(1, txn.send.call_count) # new mock reset call count

View file

@ -102,11 +102,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
} }
persp_deferred = defer.Deferred() persp_deferred = defer.Deferred()
@defer.inlineCallbacks async def get_perspectives(**kwargs):
def get_perspectives(**kwargs):
self.assertEquals(current_context().request, "11") self.assertEquals(current_context().request, "11")
with PreserveLoggingContext(): with PreserveLoggingContext():
yield persp_deferred await persp_deferred
return persp_resp return persp_resp
self.http_client.post_json.side_effect = get_perspectives self.http_client.post_json.side_effect = get_perspectives
@ -355,7 +354,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
} }
signedjson.sign.sign_json(response, SERVER_NAME, testkey) signedjson.sign.sign_json(response, SERVER_NAME, testkey)
def get_json(destination, path, **kwargs): async def get_json(destination, path, **kwargs):
self.assertEqual(destination, SERVER_NAME) self.assertEqual(destination, SERVER_NAME)
self.assertEqual(path, "/_matrix/key/v2/server/key1") self.assertEqual(path, "/_matrix/key/v2/server/key1")
return response return response
@ -444,7 +443,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
Tell the mock http client to expect a perspectives-server key query Tell the mock http client to expect a perspectives-server key query
""" """
def post_json(destination, path, data, **kwargs): async def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name) self.assertEqual(destination, self.mock_perspective_server.server_name)
self.assertEqual(path, "/_matrix/key/v2/query") self.assertEqual(path, "/_matrix/key/v2/query")
@ -580,14 +579,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
# remove the perspectives server's signature # remove the perspectives server's signature
response = build_response() response = build_response()
del response["signatures"][self.mock_perspective_server.server_name] del response["signatures"][self.mock_perspective_server.server_name]
self.http_client.post_json.return_value = {"server_keys": [response]}
keys = get_key_from_perspectives(response) keys = get_key_from_perspectives(response)
self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig") self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
# remove the origin server's signature # remove the origin server's signature
response = build_response() response = build_response()
del response["signatures"][SERVER_NAME] del response["signatures"][SERVER_NAME]
self.http_client.post_json.return_value = {"server_keys": [response]}
keys = get_key_from_perspectives(response) keys = get_key_from_perspectives(response)
self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig") self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")

View file

@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, room
from synapse.types import UserID from synapse.types import UserID
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
class RoomComplexityTests(unittest.FederatingHomeserverTestCase): class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
@ -78,9 +79,40 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock( handler.federation_handler.do_invite_join = Mock(
return_value=defer.succeed(("", 1)) return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
None,
["other.example.com"],
"roomid",
UserID.from_string(u1),
{"membership": "join"},
)
self.pump()
# The request failed with a SynapseError saying the resource limit was
# exceeded.
f = self.get_failure(d, SynapseError)
self.assertEqual(f.value.code, 400, f.value)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_join_too_large_admin(self):
# Check whether an admin can join if option "admins_can_join" is undefined,
# this option defaults to false, so the join should fail.
u1 = self.register_user("u1", "pass", admin=True)
handler = self.hs.get_room_member_handler()
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
return_value=make_awaitable(("", 1))
) )
d = handler._remote_join( d = handler._remote_join(
@ -116,9 +148,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed(None)) fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
handler.federation_handler.do_invite_join = Mock( handler.federation_handler.do_invite_join = Mock(
return_value=defer.succeed(("", 1)) return_value=make_awaitable(("", 1))
) )
# Artificially raise the complexity # Artificially raise the complexity
@ -141,3 +173,81 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
f = self.get_failure(d, SynapseError) f = self.get_failure(d, SynapseError)
self.assertEqual(f.value.code, 400) self.assertEqual(f.value.code, 400)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
# Test the behavior of joining rooms which exceed the complexity if option
# limit_remote_rooms.admins_can_join is True.
servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]
def default_config(self):
config = super().default_config()
config["limit_remote_rooms"] = {
"enabled": True,
"complexity": 0.05,
"admins_can_join": True,
}
return config
def test_join_too_large_no_admin(self):
# A user which is not an admin should not be able to join a remote room
# which is too complex.
u1 = self.register_user("u1", "pass")
handler = self.hs.get_room_member_handler()
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
None,
["other.example.com"],
"roomid",
UserID.from_string(u1),
{"membership": "join"},
)
self.pump()
# The request failed with a SynapseError saying the resource limit was
# exceeded.
f = self.get_failure(d, SynapseError)
self.assertEqual(f.value.code, 400, f.value)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_join_too_large_admin(self):
# An admin should be able to join rooms where a complexity check fails.
u1 = self.register_user("u1", "pass", admin=True)
handler = self.hs.get_room_member_handler()
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
None,
["other.example.com"],
"roomid",
UserID.from_string(u1),
{"membership": "join"},
)
self.pump()
# The request success since the user is an admin
self.get_success(d)

View file

@ -47,13 +47,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
mock_send_transaction = ( mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction self.hs.get_federation_transport_client().send_transaction
) )
mock_send_transaction.return_value = defer.succeed({}) mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender() sender = self.hs.get_federation_sender()
receipt = ReadReceipt( receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
) )
self.successResultOf(sender.send_read_receipt(receipt)) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump() self.pump()
@ -87,13 +87,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
mock_send_transaction = ( mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction self.hs.get_federation_transport_client().send_transaction
) )
mock_send_transaction.return_value = defer.succeed({}) mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender() sender = self.hs.get_federation_sender()
receipt = ReadReceipt( receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
) )
self.successResultOf(sender.send_read_receipt(receipt)) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump() self.pump()
@ -125,7 +125,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
receipt = ReadReceipt( receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
) )
self.successResultOf(sender.send_read_receipt(receipt)) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump() self.pump()
mock_send_transaction.assert_not_called() mock_send_transaction.assert_not_called()

Some files were not shown because too many files have changed in this diff Show more