forked from MirrorHub/synapse
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/add_rate_limiting_to_joins
This commit is contained in:
commit
faba873d4b
121 changed files with 1937 additions and 1256 deletions
109
INSTALL.md
109
INSTALL.md
|
@ -1,10 +1,12 @@
|
|||
- [Choosing your server name](#choosing-your-server-name)
|
||||
- [Picking a database engine](#picking-a-database-engine)
|
||||
- [Installing Synapse](#installing-synapse)
|
||||
- [Installing from source](#installing-from-source)
|
||||
- [Platform-Specific Instructions](#platform-specific-instructions)
|
||||
- [Prebuilt packages](#prebuilt-packages)
|
||||
- [Setting up Synapse](#setting-up-synapse)
|
||||
- [TLS certificates](#tls-certificates)
|
||||
- [Client Well-Known URI](#client-well-known-uri)
|
||||
- [Email](#email)
|
||||
- [Registering a user](#registering-a-user)
|
||||
- [Setting up a TURN server](#setting-up-a-turn-server)
|
||||
|
@ -27,6 +29,25 @@ that your email address is probably `user@example.com` rather than
|
|||
`user@email.example.com`) - but doing so may require more advanced setup: see
|
||||
[Setting up Federation](docs/federate.md).
|
||||
|
||||
# Picking a database engine
|
||||
|
||||
Synapse offers two database engines:
|
||||
* [PostgreSQL](https://www.postgresql.org)
|
||||
* [SQLite](https://sqlite.org/)
|
||||
|
||||
Almost all installations should opt to use PostgreSQL. Advantages include:
|
||||
|
||||
* significant performance improvements due to the superior threading and
|
||||
caching model, smarter query optimiser
|
||||
* allowing the DB to be run on separate hardware
|
||||
|
||||
For information on how to install and use PostgreSQL, please see
|
||||
[docs/postgres.md](docs/postgres.md)
|
||||
|
||||
By default Synapse uses SQLite and in doing so trades performance for convenience.
|
||||
SQLite is only recommended in Synapse for testing purposes or for servers with
|
||||
light workloads.
|
||||
|
||||
# Installing Synapse
|
||||
|
||||
## Installing from source
|
||||
|
@ -234,9 +255,9 @@ for a number of platforms.
|
|||
|
||||
There is an offical synapse image available at
|
||||
https://hub.docker.com/r/matrixdotorg/synapse which can be used with
|
||||
the docker-compose file available at [contrib/docker](contrib/docker). Further information on
|
||||
this including configuration options is available in the README on
|
||||
hub.docker.com.
|
||||
the docker-compose file available at [contrib/docker](contrib/docker). Further
|
||||
information on this including configuration options is available in the README
|
||||
on hub.docker.com.
|
||||
|
||||
Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a
|
||||
Dockerfile to automate a synapse server in a single Docker image, at
|
||||
|
@ -244,7 +265,8 @@ https://hub.docker.com/r/avhost/docker-matrix/tags/
|
|||
|
||||
Slavi Pantaleev has created an Ansible playbook,
|
||||
which installs the offical Docker image of Matrix Synapse
|
||||
along with many other Matrix-related services (Postgres database, riot-web, coturn, mxisd, SSL support, etc.).
|
||||
along with many other Matrix-related services (Postgres database, Element, coturn,
|
||||
ma1sd, SSL support, etc.).
|
||||
For more details, see
|
||||
https://github.com/spantaleev/matrix-docker-ansible-deploy
|
||||
|
||||
|
@ -277,22 +299,27 @@ The fingerprint of the repository signing key (as shown by `gpg
|
|||
/usr/share/keyrings/matrix-org-archive-keyring.gpg`) is
|
||||
`AAF9AE843A7584B5A3E4CD2BCF45A512DE2DA058`.
|
||||
|
||||
#### Downstream Debian/Ubuntu packages
|
||||
#### Downstream Debian packages
|
||||
|
||||
For `buster` and `sid`, Synapse is available in the Debian repositories and
|
||||
it should be possible to install it with simply:
|
||||
We do not recommend using the packages from the default Debian `buster`
|
||||
repository at this time, as they are old and suffer from known security
|
||||
vulnerabilities. You can install the latest version of Synapse from
|
||||
[our repository](#matrixorg-packages) or from `buster-backports`. Please
|
||||
see the [Debian documentation](https://backports.debian.org/Instructions/)
|
||||
for information on how to use backports.
|
||||
|
||||
If you are using Debian `sid` or testing, Synapse is available in the default
|
||||
repositories and it should be possible to install it simply with:
|
||||
|
||||
```
|
||||
sudo apt install matrix-synapse
|
||||
```
|
||||
|
||||
There is also a version of `matrix-synapse` in `stretch-backports`. Please see
|
||||
the [Debian documentation on
|
||||
backports](https://backports.debian.org/Instructions/) for information on how
|
||||
to use them.
|
||||
#### Downstream Ubuntu packages
|
||||
|
||||
We do not recommend using the packages in downstream Ubuntu at this time, as
|
||||
they are old and suffer from known security vulnerabilities.
|
||||
We do not recommend using the packages in the default Ubuntu repository
|
||||
at this time, as they are old and suffer from known security vulnerabilities.
|
||||
The latest version of Synapse can be installed from [our repository](#matrixorg-packages).
|
||||
|
||||
### Fedora
|
||||
|
||||
|
@ -419,6 +446,60 @@ so, you will need to edit `homeserver.yaml`, as follows:
|
|||
For a more detailed guide to configuring your server for federation, see
|
||||
[federate.md](docs/federate.md).
|
||||
|
||||
## Client Well-Known URI
|
||||
|
||||
Setting up the client Well-Known URI is optional but if you set it up, it will
|
||||
allow users to enter their full username (e.g. `@user:<server_name>`) into clients
|
||||
which support well-known lookup to automatically configure the homeserver and
|
||||
identity server URLs. This is useful so that users don't have to memorize or think
|
||||
about the actual homeserver URL you are using.
|
||||
|
||||
The URL `https://<server_name>/.well-known/matrix/client` should return JSON in
|
||||
the following format.
|
||||
|
||||
```
|
||||
{
|
||||
"m.homeserver": {
|
||||
"base_url": "https://<matrix.example.com>"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
It can optionally contain identity server information as well.
|
||||
|
||||
```
|
||||
{
|
||||
"m.homeserver": {
|
||||
"base_url": "https://<matrix.example.com>"
|
||||
},
|
||||
"m.identity_server": {
|
||||
"base_url": "https://<identity.example.com>"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
To work in browser based clients, the file must be served with the appropriate
|
||||
Cross-Origin Resource Sharing (CORS) headers. A recommended value would be
|
||||
`Access-Control-Allow-Origin: *` which would allow all browser based clients to
|
||||
view it.
|
||||
|
||||
In nginx this would be something like:
|
||||
```
|
||||
location /.well-known/matrix/client {
|
||||
return 200 '{"m.homeserver": {"base_url": "https://<matrix.example.com>"}}';
|
||||
add_header Content-Type application/json;
|
||||
add_header Access-Control-Allow-Origin *;
|
||||
}
|
||||
```
|
||||
|
||||
You should also ensure the `public_baseurl` option in `homeserver.yaml` is set
|
||||
correctly. `public_baseurl` should be set to the URL that clients will use to
|
||||
connect to your server. This is the same URL you put for the `m.homeserver`
|
||||
`base_url` above.
|
||||
|
||||
```
|
||||
public_baseurl: "https://<matrix.example.com>"
|
||||
```
|
||||
|
||||
## Email
|
||||
|
||||
|
@ -437,7 +518,7 @@ email will be disabled.
|
|||
|
||||
## Registering a user
|
||||
|
||||
The easiest way to create a new user is to do so from a client like [Riot](https://riot.im).
|
||||
The easiest way to create a new user is to do so from a client like [Element](https://element.io/).
|
||||
|
||||
Alternatively you can do so from the command line if you have installed via pip.
|
||||
|
||||
|
|
43
README.rst
43
README.rst
|
@ -45,7 +45,7 @@ which handle:
|
|||
- Eventually-consistent cryptographically secure synchronisation of room
|
||||
state across a global open network of federated servers and services
|
||||
- Sending and receiving extensible messages in a room with (optional)
|
||||
end-to-end encryption[1]
|
||||
end-to-end encryption
|
||||
- Inviting, joining, leaving, kicking, banning room members
|
||||
- Managing user accounts (registration, login, logout)
|
||||
- Using 3rd Party IDs (3PIDs) such as email addresses, phone numbers,
|
||||
|
@ -82,9 +82,6 @@ at the `Matrix spec <https://matrix.org/docs/spec>`_, and experiment with the
|
|||
|
||||
Thanks for using Matrix!
|
||||
|
||||
[1] End-to-end encryption is currently in beta: `blog post <https://matrix.org/blog/2016/11/21/matrixs-olm-end-to-end-encryption-security-assessment-released-and-implemented-cross-platform-on-riot-at-last>`_.
|
||||
|
||||
|
||||
Support
|
||||
=======
|
||||
|
||||
|
@ -115,12 +112,11 @@ Unless you are running a test instance of Synapse on your local machine, in
|
|||
general, you will need to enable TLS support before you can successfully
|
||||
connect from a client: see `<INSTALL.md#tls-certificates>`_.
|
||||
|
||||
An easy way to get started is to login or register via Riot at
|
||||
https://riot.im/app/#/login or https://riot.im/app/#/register respectively.
|
||||
An easy way to get started is to login or register via Element at
|
||||
https://app.element.io/#/login or https://app.element.io/#/register respectively.
|
||||
You will need to change the server you are logging into from ``matrix.org``
|
||||
and instead specify a Homeserver URL of ``https://<server_name>:8448``
|
||||
(or just ``https://<server_name>`` if you are using a reverse proxy).
|
||||
(Leave the identity server as the default - see `Identity servers`_.)
|
||||
If you prefer to use another client, refer to our
|
||||
`client breakdown <https://matrix.org/docs/projects/clients-matrix>`_.
|
||||
|
||||
|
@ -137,7 +133,7 @@ it, specify ``enable_registration: true`` in ``homeserver.yaml``. (It is then
|
|||
recommended to also set up CAPTCHA - see `<docs/CAPTCHA_SETUP.md>`_.)
|
||||
|
||||
Once ``enable_registration`` is set to ``true``, it is possible to register a
|
||||
user via `riot.im <https://riot.im/app/#/register>`_ or other Matrix clients.
|
||||
user via a Matrix client.
|
||||
|
||||
Your new user name will be formed partly from the ``server_name``, and partly
|
||||
from a localpart you specify when you create the account. Your name will take
|
||||
|
@ -183,30 +179,6 @@ versions of synapse.
|
|||
|
||||
.. _UPGRADE.rst: UPGRADE.rst
|
||||
|
||||
|
||||
Using PostgreSQL
|
||||
================
|
||||
|
||||
Synapse offers two database engines:
|
||||
* `PostgreSQL <https://www.postgresql.org>`_
|
||||
* `SQLite <https://sqlite.org/>`_
|
||||
|
||||
Almost all installations should opt to use PostgreSQL. Advantages include:
|
||||
|
||||
* significant performance improvements due to the superior threading and
|
||||
caching model, smarter query optimiser
|
||||
* allowing the DB to be run on separate hardware
|
||||
* allowing basic active/backup high-availability with a "hot spare" synapse
|
||||
pointing at the same DB master, as well as enabling DB replication in
|
||||
synapse itself.
|
||||
|
||||
For information on how to install and use PostgreSQL, please see
|
||||
`docs/postgres.md <docs/postgres.md>`_.
|
||||
|
||||
By default Synapse uses SQLite and in doing so trades performance for convenience.
|
||||
SQLite is only recommended in Synapse for testing purposes or for servers with
|
||||
light workloads.
|
||||
|
||||
.. _reverse-proxy:
|
||||
|
||||
Using a reverse proxy with Synapse
|
||||
|
@ -255,10 +227,9 @@ email address.
|
|||
Password reset
|
||||
==============
|
||||
|
||||
If a user has registered an email address to their account using an identity
|
||||
server, they can request a password-reset token via clients such as Riot.
|
||||
|
||||
A manual password reset can be done via direct database access as follows.
|
||||
Users can reset their password through their client. Alternatively, a server admin
|
||||
can reset a users password using the `admin API <docs/admin_api/user_admin_api.rst#reset-password>`_
|
||||
or by directly editing the database as shown below.
|
||||
|
||||
First calculate the hash of the new password::
|
||||
|
||||
|
|
1
changelog.d/7736.feature
Normal file
1
changelog.d/7736.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654).
|
1
changelog.d/7899.doc
Normal file
1
changelog.d/7899.doc
Normal file
|
@ -0,0 +1 @@
|
|||
Document how to set up a Client Well-Known file and fix several pieces of outdated documentation.
|
1
changelog.d/7902.feature
Normal file
1
changelog.d/7902.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add option to allow server admins to join rooms which fail complexity checks. Contributed by @lugino-emeritus.
|
1
changelog.d/7936.misc
Normal file
1
changelog.d/7936.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Switch to the JSON implementation from the standard library and bump the minimum version of the canonicaljson library to 1.2.0.
|
1
changelog.d/7947.misc
Normal file
1
changelog.d/7947.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
1
changelog.d/7948.misc
Normal file
1
changelog.d/7948.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
1
changelog.d/7949.misc
Normal file
1
changelog.d/7949.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
1
changelog.d/7951.misc
Normal file
1
changelog.d/7951.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
1
changelog.d/7952.misc
Normal file
1
changelog.d/7952.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Move some database-related log lines from the default logger to the database/transaction loggers.
|
1
changelog.d/7963.misc
Normal file
1
changelog.d/7963.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
1
changelog.d/7964.feature
Normal file
1
changelog.d/7964.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add an option to purge room or not with delete room admin endpoint (`POST /_synapse/admin/v1/rooms/<room_id>/delete`). Contributed by @dklimpel.
|
1
changelog.d/7965.misc
Normal file
1
changelog.d/7965.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add a script to detect source code files using non-unix line terminators.
|
1
changelog.d/7970.misc
Normal file
1
changelog.d/7970.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add a script to detect source code files using non-unix line terminators.
|
1
changelog.d/7971.misc
Normal file
1
changelog.d/7971.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Log the SAML session ID during creation.
|
1
changelog.d/7973.misc
Normal file
1
changelog.d/7973.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
1
changelog.d/7975.misc
Normal file
1
changelog.d/7975.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
1
changelog.d/7976.misc
Normal file
1
changelog.d/7976.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
1
changelog.d/7978.bugfix
Normal file
1
changelog.d/7978.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix a long standing bug: 'Duplicate key value violates unique constraint "event_relations_id"' when message retention is configured.
|
1
changelog.d/7979.misc
Normal file
1
changelog.d/7979.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Switch to the JSON implementation from the standard library and bump the minimum version of the canonicaljson library to 1.2.0.
|
1
changelog.d/7980.bugfix
Normal file
1
changelog.d/7980.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix "no create event in auth events" when trying to reject invitation after inviter leaves. Bug introduced in Synapse v1.10.0.
|
1
changelog.d/7981.misc
Normal file
1
changelog.d/7981.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
1
changelog.d/7990.doc
Normal file
1
changelog.d/7990.doc
Normal file
|
@ -0,0 +1 @@
|
|||
Improve workers docs.
|
1
changelog.d/7992.doc
Normal file
1
changelog.d/7992.doc
Normal file
|
@ -0,0 +1 @@
|
|||
Fix typo in `docs/workers.md`.
|
1
changelog.d/7998.doc
Normal file
1
changelog.d/7998.doc
Normal file
|
@ -0,0 +1 @@
|
|||
Add documentation for how to undo a room shutdown.
|
|
@ -609,7 +609,8 @@ class SynapseCmd(cmd.Cmd):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _do_event_stream(self, timeout):
|
||||
res = yield self.http_client.get_json(
|
||||
res = yield defer.ensureDeferred(
|
||||
self.http_client.get_json(
|
||||
self._url() + "/events",
|
||||
{
|
||||
"access_token": self._tok(),
|
||||
|
@ -617,6 +618,7 @@ class SynapseCmd(cmd.Cmd):
|
|||
"from": self.event_stream_token,
|
||||
},
|
||||
)
|
||||
)
|
||||
print(json.dumps(res, indent=4))
|
||||
|
||||
if "chunk" in res:
|
||||
|
|
10
debian/changelog
vendored
10
debian/changelog
vendored
|
@ -1,3 +1,13 @@
|
|||
matrix-synapse-py3 (1.xx.0) stable; urgency=medium
|
||||
|
||||
[ Synapse Packaging team ]
|
||||
* New synapse release 1.xx.0.
|
||||
|
||||
[ Aaron Raimist ]
|
||||
* Fix outdated documentation for SYNAPSE_CACHE_FACTOR
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> XXXXX
|
||||
|
||||
matrix-synapse-py3 (1.18.0) stable; urgency=medium
|
||||
|
||||
* New synapse release 1.18.0.
|
||||
|
|
2
debian/matrix-synapse.default
vendored
2
debian/matrix-synapse.default
vendored
|
@ -1,2 +1,2 @@
|
|||
# Specify environment variables used when running Synapse
|
||||
# SYNAPSE_CACHE_FACTOR=1 (default)
|
||||
# SYNAPSE_CACHE_FACTOR=0.5 (default)
|
||||
|
|
27
debian/synctl.ronn
vendored
27
debian/synctl.ronn
vendored
|
@ -46,19 +46,20 @@ Configuration file may be generated as follows:
|
|||
## ENVIRONMENT
|
||||
|
||||
* `SYNAPSE_CACHE_FACTOR`:
|
||||
Synapse's architecture is quite RAM hungry currently - a lot of
|
||||
recent room data and metadata is deliberately cached in RAM in
|
||||
order to speed up common requests. This will be improved in
|
||||
future, but for now the easiest way to either reduce the RAM usage
|
||||
(at the risk of slowing things down) is to set the
|
||||
SYNAPSE_CACHE_FACTOR environment variable. Roughly speaking, a
|
||||
SYNAPSE_CACHE_FACTOR of 1.0 will max out at around 3-4GB of
|
||||
resident memory - this is what we currently run the matrix.org
|
||||
on. The default setting is currently 0.1, which is probably around
|
||||
a ~700MB footprint. You can dial it down further to 0.02 if
|
||||
desired, which targets roughly ~512MB. Conversely you can dial it
|
||||
up if you need performance for lots of users and have a box with a
|
||||
lot of RAM.
|
||||
Synapse's architecture is quite RAM hungry currently - we deliberately
|
||||
cache a lot of recent room data and metadata in RAM in order to speed up
|
||||
common requests. We'll improve this in the future, but for now the easiest
|
||||
way to either reduce the RAM usage (at the risk of slowing things down)
|
||||
is to set the almost-undocumented ``SYNAPSE_CACHE_FACTOR`` environment
|
||||
variable. The default is 0.5, which can be decreased to reduce RAM usage
|
||||
in memory constrained enviroments, or increased if performance starts to
|
||||
degrade.
|
||||
|
||||
However, degraded performance due to a low cache factor, common on
|
||||
machines with slow disks, often leads to explosions in memory use due
|
||||
backlogged requests. In this case, reducing the cache factor will make
|
||||
things worse. Instead, try increasing it drastically. 2.0 is a good
|
||||
starting value.
|
||||
|
||||
## COPYRIGHT
|
||||
|
||||
|
|
|
@ -10,5 +10,16 @@
|
|||
# homeserver.yaml. Instead, if you are starting from scratch, please generate
|
||||
# a fresh config using Synapse by following the instructions in INSTALL.md.
|
||||
|
||||
# Configuration options that take a time period can be set using a number
|
||||
# followed by a letter. Letters have the following meanings:
|
||||
# s = second
|
||||
# m = minute
|
||||
# h = hour
|
||||
# d = day
|
||||
# w = week
|
||||
# y = year
|
||||
# For example, setting redaction_retention_period: 5m would remove redacted
|
||||
# messages from the database after 5 minutes, rather than 5 months.
|
||||
|
||||
################################################################################
|
||||
|
||||
|
|
|
@ -369,7 +369,9 @@ to the new room will have power level `-10` by default, and thus be unable to sp
|
|||
If `block` is `True` it prevents new joins to the old room.
|
||||
|
||||
This API will remove all trace of the old room from your database after removing
|
||||
all local users.
|
||||
all local users. If `purge` is `true` (the default), all traces of the old room will
|
||||
be removed from your database after removing all local users. If you do not want
|
||||
this to happen, set `purge` to `false`.
|
||||
Depending on the amount of history being purged a call to the API may take
|
||||
several minutes or longer.
|
||||
|
||||
|
@ -388,7 +390,8 @@ with a body of:
|
|||
"new_room_user_id": "@someuser:example.com",
|
||||
"room_name": "Content Violation Notification",
|
||||
"message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service.",
|
||||
"block": true
|
||||
"block": true,
|
||||
"purge": true
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -430,8 +433,10 @@ The following JSON body parameters are available:
|
|||
`new_room_user_id` in the new room. Ideally this will clearly convey why the
|
||||
original room was shut down. Defaults to `Sharing illegal content on this server
|
||||
is not permitted and rooms in violation will be blocked.`
|
||||
* `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing future attempts to
|
||||
join the room. Defaults to `false`.
|
||||
* `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing
|
||||
future attempts to join the room. Defaults to `false`.
|
||||
* `purge` - Optional. If set to `true`, it will remove all traces of the room from your database.
|
||||
Defaults to `true`.
|
||||
|
||||
The JSON body must not be empty. The body must be at least `{}`.
|
||||
|
||||
|
|
|
@ -72,3 +72,23 @@ Response:
|
|||
"new_room_id": "!newroomid:example.com",
|
||||
},
|
||||
```
|
||||
|
||||
## Undoing room shutdowns
|
||||
|
||||
*Note*: This guide may be outdated by the time you read it. By nature of room shutdowns being performed at the database level,
|
||||
the structure can and does change without notice.
|
||||
|
||||
First, it's important to understand that a room shutdown is very destructive. Undoing a shutdown is not as simple as pretending it
|
||||
never happened - work has to be done to move forward instead of resetting the past.
|
||||
|
||||
1. For safety reasons, it is recommended to shut down Synapse prior to continuing.
|
||||
2. In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';`
|
||||
* For caution: it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got 1 result, then `COMMIT;`.
|
||||
* The room ID is the same one supplied to the shutdown room API, not the Content Violation room.
|
||||
3. Restart Synapse (required).
|
||||
|
||||
You will have to manually handle, if you so choose, the following:
|
||||
|
||||
* Aliases that would have been redirected to the Content Violation room.
|
||||
* Users that would have been booted from the room (and will have been force-joined to the Content Violation room).
|
||||
* Removal of the Content Violation room if desired.
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
different thread to Synapse. This can make it more resilient to
|
||||
heavy load meaning metrics cannot be retrieved, and can be exposed
|
||||
to just internal networks easier. The served metrics are available
|
||||
over HTTP only, and will be available at `/`.
|
||||
over HTTP only, and will be available at `/_synapse/metrics`.
|
||||
|
||||
Add a new listener to homeserver.yaml:
|
||||
|
||||
|
|
|
@ -188,6 +188,9 @@ to do step 2.
|
|||
|
||||
It is safe to at any time kill the port script and restart it.
|
||||
|
||||
Note that the database may take up significantly more (25% - 100% more)
|
||||
space on disk after porting to Postgres.
|
||||
|
||||
### Using the port script
|
||||
|
||||
Firstly, shut down the currently running synapse server and copy its
|
||||
|
|
|
@ -10,6 +10,17 @@
|
|||
# homeserver.yaml. Instead, if you are starting from scratch, please generate
|
||||
# a fresh config using Synapse by following the instructions in INSTALL.md.
|
||||
|
||||
# Configuration options that take a time period can be set using a number
|
||||
# followed by a letter. Letters have the following meanings:
|
||||
# s = second
|
||||
# m = minute
|
||||
# h = hour
|
||||
# d = day
|
||||
# w = week
|
||||
# y = year
|
||||
# For example, setting redaction_retention_period: 5m would remove redacted
|
||||
# messages from the database after 5 minutes, rather than 5 months.
|
||||
|
||||
################################################################################
|
||||
|
||||
# Configuration file for Synapse.
|
||||
|
@ -314,6 +325,10 @@ limit_remote_rooms:
|
|||
#
|
||||
#complexity_error: "This room is too complex."
|
||||
|
||||
# allow server admins to join complex rooms. Default is false.
|
||||
#
|
||||
#admins_can_join: true
|
||||
|
||||
# Whether to require a user to be in the room to add an alias to it.
|
||||
# Defaults to 'true'.
|
||||
#
|
||||
|
@ -1157,24 +1172,6 @@ account_validity:
|
|||
#
|
||||
#default_identity_server: https://matrix.org
|
||||
|
||||
# The list of identity servers trusted to verify third party
|
||||
# identifiers by this server.
|
||||
#
|
||||
# Also defines the ID server which will be called when an account is
|
||||
# deactivated (one will be picked arbitrarily).
|
||||
#
|
||||
# Note: This option is deprecated. Since v0.99.4, Synapse has tracked which identity
|
||||
# server a 3PID has been bound to. For 3PIDs bound before then, Synapse runs a
|
||||
# background migration script, informing itself that the identity server all of its
|
||||
# 3PIDs have been bound to is likely one of the below.
|
||||
#
|
||||
# As of Synapse v1.4.0, all other functionality of this option has been deprecated, and
|
||||
# it is now solely used for the purposes of the background migration script, and can be
|
||||
# removed once it has run.
|
||||
#trusted_third_party_id_servers:
|
||||
# - matrix.org
|
||||
# - vector.im
|
||||
|
||||
# Handle threepid (email/phone etc) registration and password resets through a set of
|
||||
# *trusted* identity servers. Note that this allows the configured identity server to
|
||||
# reset passwords for accounts!
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
# Scaling synapse via workers
|
||||
|
||||
For small instances it recommended to run Synapse in monolith mode (the
|
||||
default). For larger instances where performance is a concern it can be helpful
|
||||
to split out functionality into multiple separate python processes. These
|
||||
processes are called 'workers', and are (eventually) intended to scale
|
||||
horizontally independently.
|
||||
For small instances it recommended to run Synapse in the default monolith mode.
|
||||
For larger instances where performance is a concern it can be helpful to split
|
||||
out functionality into multiple separate python processes. These processes are
|
||||
called 'workers', and are (eventually) intended to scale horizontally
|
||||
independently.
|
||||
|
||||
Synapse's worker support is under active development and subject to change as
|
||||
we attempt to rapidly scale ever larger Synapse instances. However we are
|
||||
|
@ -23,29 +23,30 @@ The processes communicate with each other via a Synapse-specific protocol called
|
|||
feeds streams of newly written data between processes so they can be kept in
|
||||
sync with the database state.
|
||||
|
||||
Additionally, processes may make HTTP requests to each other. Typically this is
|
||||
used for operations which need to wait for a reply - such as sending an event.
|
||||
When configured to do so, Synapse uses a
|
||||
[Redis pub/sub channel](https://redis.io/topics/pubsub) to send the replication
|
||||
stream between all configured Synapse processes. Additionally, processes may
|
||||
make HTTP requests to each other, primarily for operations which need to wait
|
||||
for a reply ─ such as sending an event.
|
||||
|
||||
As of Synapse v1.13.0, it is possible to configure Synapse to send replication
|
||||
via a [Redis pub/sub channel](https://redis.io/topics/pubsub), and is now the
|
||||
recommended way of configuring replication. This is an alternative to the old
|
||||
direct TCP connections to the main process: rather than all the workers
|
||||
connecting to the main process, all the workers and the main process connect to
|
||||
Redis, which relays replication commands between processes. This can give a
|
||||
significant cpu saving on the main process and will be a prerequisite for
|
||||
upcoming performance improvements.
|
||||
Redis support was added in v1.13.0 with it becoming the recommended method in
|
||||
v1.18.0. It replaced the old direct TCP connections (which is deprecated as of
|
||||
v1.18.0) to the main process. With Redis, rather than all the workers connecting
|
||||
to the main process, all the workers and the main process connect to Redis,
|
||||
which relays replication commands between processes. This can give a significant
|
||||
cpu saving on the main process and will be a prerequisite for upcoming
|
||||
performance improvements.
|
||||
|
||||
(See the [Architectural diagram](#architectural-diagram) section at the end for
|
||||
a visualisation of what this looks like)
|
||||
See the [Architectural diagram](#architectural-diagram) section at the end for
|
||||
a visualisation of what this looks like.
|
||||
|
||||
|
||||
## Setting up workers
|
||||
|
||||
A Redis server is required to manage the communication between the processes.
|
||||
(The older direct TCP connections are now deprecated.) The Redis server
|
||||
should be installed following the normal procedure for your distribution (e.g.
|
||||
`apt install redis-server` on Debian). It is safe to use an existing Redis
|
||||
deployment if you have one.
|
||||
The Redis server should be installed following the normal procedure for your
|
||||
distribution (e.g. `apt install redis-server` on Debian). It is safe to use an
|
||||
existing Redis deployment if you have one.
|
||||
|
||||
Once installed, check that Redis is running and accessible from the host running
|
||||
Synapse, for example by executing `echo PING | nc -q1 localhost 6379` and seeing
|
||||
|
@ -65,8 +66,9 @@ https://hub.docker.com/r/matrixdotorg/synapse/.
|
|||
|
||||
To make effective use of the workers, you will need to configure an HTTP
|
||||
reverse-proxy such as nginx or haproxy, which will direct incoming requests to
|
||||
the correct worker, or to the main synapse instance. See [reverse_proxy.md](reverse_proxy.md)
|
||||
for information on setting up a reverse proxy.
|
||||
the correct worker, or to the main synapse instance. See
|
||||
[reverse_proxy.md](reverse_proxy.md) for information on setting up a reverse
|
||||
proxy.
|
||||
|
||||
To enable workers you should create a configuration file for each worker
|
||||
process. Each worker configuration file inherits the configuration of the shared
|
||||
|
@ -75,8 +77,12 @@ that worker, e.g. the HTTP listener that it provides (if any); logging
|
|||
configuration; etc. You should minimise the number of overrides though to
|
||||
maintain a usable config.
|
||||
|
||||
Next you need to add both a HTTP replication listener and redis config to the
|
||||
shared Synapse configuration file (`homeserver.yaml`). For example:
|
||||
|
||||
### Shared Configuration
|
||||
|
||||
Next you need to add both a HTTP replication listener, used for HTTP requests
|
||||
between processes, and redis config to the shared Synapse configuration file
|
||||
(`homeserver.yaml`). For example:
|
||||
|
||||
```yaml
|
||||
# extend the existing `listeners` section. This defines the ports that the
|
||||
|
@ -98,6 +104,9 @@ See the sample config for the full documentation of each option.
|
|||
Under **no circumstances** should the replication listener be exposed to the
|
||||
public internet; it has no authentication and is unencrypted.
|
||||
|
||||
|
||||
### Worker Configuration
|
||||
|
||||
In the config file for each worker, you must specify the type of worker
|
||||
application (`worker_app`), and you should specify a unqiue name for the worker
|
||||
(`worker_name`). The currently available worker applications are listed below.
|
||||
|
@ -278,7 +287,7 @@ instance_map:
|
|||
host: localhost
|
||||
port: 8034
|
||||
|
||||
streams_writers:
|
||||
stream_writers:
|
||||
events: event_persister1
|
||||
```
|
||||
|
||||
|
|
34
scripts-dev/check_line_terminators.sh
Executable file
34
scripts-dev/check_line_terminators.sh
Executable file
|
@ -0,0 +1,34 @@
|
|||
#!/bin/bash
|
||||
#
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# This script checks that line terminators in all repository files (excluding
|
||||
# those in the .git directory) feature unix line terminators.
|
||||
#
|
||||
# Usage:
|
||||
#
|
||||
# ./check_line_terminators.sh
|
||||
#
|
||||
# The script will emit exit code 1 if any files that do not use unix line
|
||||
# terminators are found, 0 otherwise.
|
||||
|
||||
# cd to the root of the repository
|
||||
cd `dirname $0`/..
|
||||
|
||||
# Find and print files with non-unix line terminators
|
||||
if find . -path './.git/*' -prune -o -type f -print0 | xargs -0 grep -I -l $'\r$'; then
|
||||
echo -e '\e[31mERROR: found files with CRLF line endings. See above.\e[39m'
|
||||
exit 1
|
||||
fi
|
|
@ -69,7 +69,7 @@ logger = logging.getLogger("synapse_port_db")
|
|||
|
||||
|
||||
BOOLEAN_COLUMNS = {
|
||||
"events": ["processed", "outlier", "contains_url"],
|
||||
"events": ["processed", "outlier", "contains_url", "count_as_unread"],
|
||||
"rooms": ["is_public"],
|
||||
"event_edges": ["is_state"],
|
||||
"presence_list": ["accepted"],
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
""" This is a reference implementation of a Matrix homeserver.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
@ -25,6 +26,9 @@ if sys.version_info < (3, 5):
|
|||
print("Synapse requires Python 3.5 or above.")
|
||||
sys.exit(1)
|
||||
|
||||
# Twisted and canonicaljson will fail to import when this file is executed to
|
||||
# get the __version__ during a fresh install. That's OK and subsequent calls to
|
||||
# actually start Synapse will import these libraries fine.
|
||||
try:
|
||||
from twisted.internet import protocol
|
||||
from twisted.internet.protocol import Factory
|
||||
|
@ -36,6 +40,14 @@ try:
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
# Use the standard library json implementation instead of simplejson.
|
||||
try:
|
||||
from canonicaljson import set_json_library
|
||||
|
||||
set_json_library(json)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
__version__ = "1.18.0"
|
||||
|
||||
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
||||
|
|
|
@ -82,7 +82,7 @@ class Auth(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def check_from_context(self, room_version: str, event, context, do_sig_check=True):
|
||||
prev_state_ids = yield context.get_prev_state_ids()
|
||||
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||
auth_events_ids = yield self.compute_auth_events(
|
||||
event, prev_state_ids, for_verification=True
|
||||
)
|
||||
|
|
|
@ -628,7 +628,7 @@ class GenericWorkerServer(HomeServer):
|
|||
|
||||
self.get_tcp_replication().start_replication(self)
|
||||
|
||||
def remove_pusher(self, app_id, push_key, user_id):
|
||||
async def remove_pusher(self, app_id, push_key, user_id):
|
||||
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
|
||||
|
||||
def build_replication_data_handler(self):
|
||||
|
|
|
@ -15,11 +15,9 @@
|
|||
import logging
|
||||
import re
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.types import GroupID, get_domain_from_id
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -43,7 +41,7 @@ class AppServiceTransaction(object):
|
|||
Args:
|
||||
as_api(ApplicationServiceApi): The API to use to send.
|
||||
Returns:
|
||||
A Deferred which resolves to True if the transaction was sent.
|
||||
An Awaitable which resolves to True if the transaction was sent.
|
||||
"""
|
||||
return as_api.push_bulk(
|
||||
service=self.service, events=self.events, txn_id=self.id
|
||||
|
@ -172,8 +170,7 @@ class ApplicationService(object):
|
|||
return regex_obj["exclusive"]
|
||||
return False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _matches_user(self, event, store):
|
||||
async def _matches_user(self, event, store):
|
||||
if not event:
|
||||
return False
|
||||
|
||||
|
@ -188,12 +185,12 @@ class ApplicationService(object):
|
|||
if not store:
|
||||
return False
|
||||
|
||||
does_match = yield self._matches_user_in_member_list(event.room_id, store)
|
||||
does_match = await self._matches_user_in_member_list(event.room_id, store)
|
||||
return does_match
|
||||
|
||||
@cachedInlineCallbacks(num_args=1, cache_context=True)
|
||||
def _matches_user_in_member_list(self, room_id, store, cache_context):
|
||||
member_list = yield store.get_users_in_room(
|
||||
@cached(num_args=1, cache_context=True)
|
||||
async def _matches_user_in_member_list(self, room_id, store, cache_context):
|
||||
member_list = await store.get_users_in_room(
|
||||
room_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
|
||||
|
@ -208,35 +205,33 @@ class ApplicationService(object):
|
|||
return self.is_interested_in_room(event.room_id)
|
||||
return False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _matches_aliases(self, event, store):
|
||||
async def _matches_aliases(self, event, store):
|
||||
if not store or not event:
|
||||
return False
|
||||
|
||||
alias_list = yield store.get_aliases_for_room(event.room_id)
|
||||
alias_list = await store.get_aliases_for_room(event.room_id)
|
||||
for alias in alias_list:
|
||||
if self.is_interested_in_alias(alias):
|
||||
return True
|
||||
return False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_interested(self, event, store=None):
|
||||
async def is_interested(self, event, store=None) -> bool:
|
||||
"""Check if this service is interested in this event.
|
||||
|
||||
Args:
|
||||
event(Event): The event to check.
|
||||
store(DataStore)
|
||||
Returns:
|
||||
bool: True if this service would like to know about this event.
|
||||
True if this service would like to know about this event.
|
||||
"""
|
||||
# Do cheap checks first
|
||||
if self._matches_room_id(event):
|
||||
return True
|
||||
|
||||
if (yield self._matches_aliases(event, store)):
|
||||
if await self._matches_aliases(event, store):
|
||||
return True
|
||||
|
||||
if (yield self._matches_user(event, store)):
|
||||
if await self._matches_user(event, store):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
|
|
@ -93,13 +93,12 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_user(self, service, user_id):
|
||||
async def query_user(self, service, user_id):
|
||||
if service.url is None:
|
||||
return False
|
||||
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
|
||||
try:
|
||||
response = yield self.get_json(uri, {"access_token": service.hs_token})
|
||||
response = await self.get_json(uri, {"access_token": service.hs_token})
|
||||
if response is not None: # just an empty json object
|
||||
return True
|
||||
except CodeMessageException as e:
|
||||
|
@ -110,14 +109,12 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
logger.warning("query_user to %s threw exception %s", uri, ex)
|
||||
return False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_alias(self, service, alias):
|
||||
async def query_alias(self, service, alias):
|
||||
if service.url is None:
|
||||
return False
|
||||
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
|
||||
response = None
|
||||
try:
|
||||
response = yield self.get_json(uri, {"access_token": service.hs_token})
|
||||
response = await self.get_json(uri, {"access_token": service.hs_token})
|
||||
if response is not None: # just an empty json object
|
||||
return True
|
||||
except CodeMessageException as e:
|
||||
|
@ -128,8 +125,7 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
logger.warning("query_alias to %s threw exception %s", uri, ex)
|
||||
return False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_3pe(self, service, kind, protocol, fields):
|
||||
async def query_3pe(self, service, kind, protocol, fields):
|
||||
if kind == ThirdPartyEntityKind.USER:
|
||||
required_field = "userid"
|
||||
elif kind == ThirdPartyEntityKind.LOCATION:
|
||||
|
@ -146,7 +142,7 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
urllib.parse.quote(protocol),
|
||||
)
|
||||
try:
|
||||
response = yield self.get_json(uri, fields)
|
||||
response = await self.get_json(uri, fields)
|
||||
if not isinstance(response, list):
|
||||
logger.warning(
|
||||
"query_3pe to %s returned an invalid response %r", uri, response
|
||||
|
@ -202,8 +198,7 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
key = (service.id, protocol)
|
||||
return self.protocol_meta_cache.wrap(key, _get)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def push_bulk(self, service, events, txn_id=None):
|
||||
async def push_bulk(self, service, events, txn_id=None):
|
||||
if service.url is None:
|
||||
return True
|
||||
|
||||
|
@ -218,7 +213,7 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
|
||||
uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
|
||||
try:
|
||||
yield self.put_json(
|
||||
await self.put_json(
|
||||
uri=uri,
|
||||
json_body={"events": events},
|
||||
args={"access_token": service.hs_token},
|
||||
|
|
|
@ -50,8 +50,6 @@ components.
|
|||
"""
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.appservice import ApplicationServiceState
|
||||
from synapse.logging.context import run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
|
@ -73,12 +71,11 @@ class ApplicationServiceScheduler(object):
|
|||
self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
|
||||
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start(self):
|
||||
async def start(self):
|
||||
logger.info("Starting appservice scheduler")
|
||||
|
||||
# check for any DOWN ASes and start recoverers for them.
|
||||
services = yield self.store.get_appservices_by_state(
|
||||
services = await self.store.get_appservices_by_state(
|
||||
ApplicationServiceState.DOWN
|
||||
)
|
||||
|
||||
|
@ -117,8 +114,7 @@ class _ServiceQueuer(object):
|
|||
"as-sender-%s" % (service.id,), self._send_request, service
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _send_request(self, service):
|
||||
async def _send_request(self, service):
|
||||
# sanity-check: we shouldn't get here if this service already has a sender
|
||||
# running.
|
||||
assert service.id not in self.requests_in_flight
|
||||
|
@ -130,7 +126,7 @@ class _ServiceQueuer(object):
|
|||
if not events:
|
||||
return
|
||||
try:
|
||||
yield self.txn_ctrl.send(service, events)
|
||||
await self.txn_ctrl.send(service, events)
|
||||
except Exception:
|
||||
logger.exception("AS request failed")
|
||||
finally:
|
||||
|
@ -162,36 +158,33 @@ class _TransactionController(object):
|
|||
# for UTs
|
||||
self.RECOVERER_CLASS = _Recoverer
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send(self, service, events):
|
||||
async def send(self, service, events):
|
||||
try:
|
||||
txn = yield self.store.create_appservice_txn(service=service, events=events)
|
||||
service_is_up = yield self._is_service_up(service)
|
||||
txn = await self.store.create_appservice_txn(service=service, events=events)
|
||||
service_is_up = await self._is_service_up(service)
|
||||
if service_is_up:
|
||||
sent = yield txn.send(self.as_api)
|
||||
sent = await txn.send(self.as_api)
|
||||
if sent:
|
||||
yield txn.complete(self.store)
|
||||
await txn.complete(self.store)
|
||||
else:
|
||||
run_in_background(self._on_txn_fail, service)
|
||||
except Exception:
|
||||
logger.exception("Error creating appservice transaction")
|
||||
run_in_background(self._on_txn_fail, service)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_recovered(self, recoverer):
|
||||
async def on_recovered(self, recoverer):
|
||||
logger.info(
|
||||
"Successfully recovered application service AS ID %s", recoverer.service.id
|
||||
)
|
||||
self.recoverers.pop(recoverer.service.id)
|
||||
logger.info("Remaining active recoverers: %s", len(self.recoverers))
|
||||
yield self.store.set_appservice_state(
|
||||
await self.store.set_appservice_state(
|
||||
recoverer.service, ApplicationServiceState.UP
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _on_txn_fail(self, service):
|
||||
async def _on_txn_fail(self, service):
|
||||
try:
|
||||
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
|
||||
await self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
|
||||
self.start_recoverer(service)
|
||||
except Exception:
|
||||
logger.exception("Error starting AS recoverer")
|
||||
|
@ -211,9 +204,8 @@ class _TransactionController(object):
|
|||
recoverer.recover()
|
||||
logger.info("Now %i active recoverers", len(self.recoverers))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _is_service_up(self, service):
|
||||
state = yield self.store.get_appservice_state(service)
|
||||
async def _is_service_up(self, service):
|
||||
state = await self.store.get_appservice_state(service)
|
||||
return state == ApplicationServiceState.UP or state is None
|
||||
|
||||
|
||||
|
@ -254,25 +246,24 @@ class _Recoverer(object):
|
|||
self.backoff_counter += 1
|
||||
self.recover()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def retry(self):
|
||||
async def retry(self):
|
||||
logger.info("Starting retries on %s", self.service.id)
|
||||
try:
|
||||
while True:
|
||||
txn = yield self.store.get_oldest_unsent_txn(self.service)
|
||||
txn = await self.store.get_oldest_unsent_txn(self.service)
|
||||
if not txn:
|
||||
# nothing left: we're done!
|
||||
self.callback(self)
|
||||
await self.callback(self)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Retrying transaction %s for AS ID %s", txn.id, txn.service.id
|
||||
)
|
||||
sent = yield txn.send(self.as_api)
|
||||
sent = await txn.send(self.as_api)
|
||||
if not sent:
|
||||
break
|
||||
|
||||
yield txn.complete(self.store)
|
||||
await txn.complete(self.store)
|
||||
|
||||
# reset the backoff counter and then process the next transaction
|
||||
self.backoff_counter = 1
|
||||
|
|
|
@ -333,24 +333,6 @@ class RegistrationConfig(Config):
|
|||
#
|
||||
#default_identity_server: https://matrix.org
|
||||
|
||||
# The list of identity servers trusted to verify third party
|
||||
# identifiers by this server.
|
||||
#
|
||||
# Also defines the ID server which will be called when an account is
|
||||
# deactivated (one will be picked arbitrarily).
|
||||
#
|
||||
# Note: This option is deprecated. Since v0.99.4, Synapse has tracked which identity
|
||||
# server a 3PID has been bound to. For 3PIDs bound before then, Synapse runs a
|
||||
# background migration script, informing itself that the identity server all of its
|
||||
# 3PIDs have been bound to is likely one of the below.
|
||||
#
|
||||
# As of Synapse v1.4.0, all other functionality of this option has been deprecated, and
|
||||
# it is now solely used for the purposes of the background migration script, and can be
|
||||
# removed once it has run.
|
||||
#trusted_third_party_id_servers:
|
||||
# - matrix.org
|
||||
# - vector.im
|
||||
|
||||
# Handle threepid (email/phone etc) registration and password resets through a set of
|
||||
# *trusted* identity servers. Note that this allows the configured identity server to
|
||||
# reset passwords for accounts!
|
||||
|
|
|
@ -439,6 +439,9 @@ class ServerConfig(Config):
|
|||
validator=attr.validators.instance_of(str),
|
||||
default=ROOM_COMPLEXITY_TOO_GREAT,
|
||||
)
|
||||
admins_can_join = attr.ib(
|
||||
validator=attr.validators.instance_of(bool), default=False
|
||||
)
|
||||
|
||||
self.limit_remote_rooms = LimitRemoteRoomsConfig(
|
||||
**(config.get("limit_remote_rooms") or {})
|
||||
|
@ -893,6 +896,10 @@ class ServerConfig(Config):
|
|||
#
|
||||
#complexity_error: "This room is too complex."
|
||||
|
||||
# allow server admins to join complex rooms. Default is false.
|
||||
#
|
||||
#admins_can_join: true
|
||||
|
||||
# Whether to require a user to be in the room to add an alias to it.
|
||||
# Defaults to 'true'.
|
||||
#
|
||||
|
|
|
@ -632,7 +632,8 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||
)
|
||||
|
||||
try:
|
||||
query_response = yield self.client.post_json(
|
||||
query_response = yield defer.ensureDeferred(
|
||||
self.client.post_json(
|
||||
destination=perspective_name,
|
||||
path="/_matrix/key/v2/query",
|
||||
data={
|
||||
|
@ -645,6 +646,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
except (NotRetryingDestination, RequestSendFailed) as e:
|
||||
# these both have str() representations which we can't really improve upon
|
||||
raise KeyLookupError(str(e))
|
||||
|
@ -792,7 +794,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
|||
|
||||
time_now_ms = self.clock.time_msec()
|
||||
try:
|
||||
response = yield self.client.get_json(
|
||||
response = yield defer.ensureDeferred(
|
||||
self.client.get_json(
|
||||
destination=server_name,
|
||||
path="/_matrix/key/v2/server/"
|
||||
+ urllib.parse.quote(requested_key_id),
|
||||
|
@ -810,6 +813,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
|||
# read the response).
|
||||
timeout=10000,
|
||||
)
|
||||
)
|
||||
except (NotRetryingDestination, RequestSendFailed) as e:
|
||||
# these both have str() representations which we can't really improve
|
||||
# upon
|
||||
|
|
|
@ -17,8 +17,6 @@ from typing import Optional
|
|||
import attr
|
||||
from nacl.signing import SigningKey
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import MAX_DEPTH
|
||||
from synapse.api.errors import UnsupportedRoomVersionError
|
||||
from synapse.api.room_versions import (
|
||||
|
@ -95,31 +93,30 @@ class EventBuilder(object):
|
|||
def is_state(self):
|
||||
return self._state_key is not None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def build(self, prev_event_ids):
|
||||
async def build(self, prev_event_ids):
|
||||
"""Transform into a fully signed and hashed event
|
||||
|
||||
Args:
|
||||
prev_event_ids (list[str]): The event IDs to use as the prev events
|
||||
|
||||
Returns:
|
||||
Deferred[FrozenEvent]
|
||||
FrozenEvent
|
||||
"""
|
||||
|
||||
state_ids = yield defer.ensureDeferred(
|
||||
self._state.get_current_state_ids(self.room_id, prev_event_ids)
|
||||
state_ids = await self._state.get_current_state_ids(
|
||||
self.room_id, prev_event_ids
|
||||
)
|
||||
auth_ids = yield self._auth.compute_auth_events(self, state_ids)
|
||||
auth_ids = await self._auth.compute_auth_events(self, state_ids)
|
||||
|
||||
format_version = self.room_version.event_format
|
||||
if format_version == EventFormatVersions.V1:
|
||||
auth_events = yield self._store.add_event_hashes(auth_ids)
|
||||
prev_events = yield self._store.add_event_hashes(prev_event_ids)
|
||||
auth_events = await self._store.add_event_hashes(auth_ids)
|
||||
prev_events = await self._store.add_event_hashes(prev_event_ids)
|
||||
else:
|
||||
auth_events = auth_ids
|
||||
prev_events = prev_event_ids
|
||||
|
||||
old_depth = yield self._store.get_max_depth_of(prev_event_ids)
|
||||
old_depth = await self._store.get_max_depth_of(prev_event_ids)
|
||||
depth = old_depth + 1
|
||||
|
||||
# we cap depth of generated events, to ensure that they are not
|
||||
|
|
|
@ -12,17 +12,19 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import attr
|
||||
from frozendict import frozendict
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.types import StateMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.storage.data_stores.main import DataStore
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class EventContext:
|
||||
|
@ -129,8 +131,7 @@ class EventContext:
|
|||
delta_ids=delta_ids,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def serialize(self, event, store):
|
||||
async def serialize(self, event: EventBase, store: "DataStore") -> dict:
|
||||
"""Converts self to a type that can be serialized as JSON, and then
|
||||
deserialized by `deserialize`
|
||||
|
||||
|
@ -146,7 +147,7 @@ class EventContext:
|
|||
# the prev_state_ids, so if we're a state event we include the event
|
||||
# id that we replaced in the state.
|
||||
if event.is_state():
|
||||
prev_state_ids = yield self.get_prev_state_ids()
|
||||
prev_state_ids = await self.get_prev_state_ids()
|
||||
prev_state_id = prev_state_ids.get((event.type, event.state_key))
|
||||
else:
|
||||
prev_state_id = None
|
||||
|
@ -214,8 +215,7 @@ class EventContext:
|
|||
|
||||
return self._state_group
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_state_ids(self):
|
||||
async def get_current_state_ids(self) -> Optional[StateMap[str]]:
|
||||
"""
|
||||
Gets the room state map, including this event - ie, the state in ``state_group``
|
||||
|
||||
|
@ -224,8 +224,8 @@ class EventContext:
|
|||
``rejected`` is set.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[(str, str), str]|None]: Returns None if state_group
|
||||
is None, which happens when the associated event is an outlier.
|
||||
Returns None if state_group is None, which happens when the associated
|
||||
event is an outlier.
|
||||
|
||||
Maps a (type, state_key) to the event ID of the state event matching
|
||||
this tuple.
|
||||
|
@ -233,23 +233,22 @@ class EventContext:
|
|||
if self.rejected:
|
||||
raise RuntimeError("Attempt to access state_ids of rejected event")
|
||||
|
||||
yield self._ensure_fetched()
|
||||
await self._ensure_fetched()
|
||||
return self._current_state_ids
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_prev_state_ids(self):
|
||||
async def get_prev_state_ids(self):
|
||||
"""
|
||||
Gets the room state map, excluding this event.
|
||||
|
||||
For a non-state event, this will be the same as get_current_state_ids().
|
||||
|
||||
Returns:
|
||||
Deferred[dict[(str, str), str]|None]: Returns None if state_group
|
||||
dict[(str, str), str]|None: Returns None if state_group
|
||||
is None, which happens when the associated event is an outlier.
|
||||
Maps a (type, state_key) to the event ID of the state event matching
|
||||
this tuple.
|
||||
"""
|
||||
yield self._ensure_fetched()
|
||||
await self._ensure_fetched()
|
||||
return self._prev_state_ids
|
||||
|
||||
def get_cached_current_state_ids(self):
|
||||
|
@ -269,8 +268,8 @@ class EventContext:
|
|||
|
||||
return self._current_state_ids
|
||||
|
||||
def _ensure_fetched(self):
|
||||
return defer.succeed(None)
|
||||
async def _ensure_fetched(self):
|
||||
return None
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
|
@ -303,21 +302,20 @@ class _AsyncEventContextImpl(EventContext):
|
|||
_event_state_key = attr.ib(default=None)
|
||||
_fetching_state_deferred = attr.ib(default=None)
|
||||
|
||||
def _ensure_fetched(self):
|
||||
async def _ensure_fetched(self):
|
||||
if not self._fetching_state_deferred:
|
||||
self._fetching_state_deferred = run_in_background(self._fill_out_state)
|
||||
|
||||
return make_deferred_yieldable(self._fetching_state_deferred)
|
||||
return await make_deferred_yieldable(self._fetching_state_deferred)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _fill_out_state(self):
|
||||
async def _fill_out_state(self):
|
||||
"""Called to populate the _current_state_ids and _prev_state_ids
|
||||
attributes by loading from the database.
|
||||
"""
|
||||
if self.state_group is None:
|
||||
return
|
||||
|
||||
self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
|
||||
self._current_state_ids = await self._storage.state.get_state_ids_for_group(
|
||||
self.state_group
|
||||
)
|
||||
if self._event_state_key is not None:
|
||||
|
|
|
@ -13,7 +13,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.types import Requester
|
||||
|
||||
|
||||
class ThirdPartyEventRules(object):
|
||||
|
@ -39,76 +41,79 @@ class ThirdPartyEventRules(object):
|
|||
config=config, http_client=hs.get_simple_http_client()
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_event_allowed(self, event, context):
|
||||
async def check_event_allowed(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> bool:
|
||||
"""Check if a provided event should be allowed in the given context.
|
||||
|
||||
Args:
|
||||
event (synapse.events.EventBase): The event to be checked.
|
||||
context (synapse.events.snapshot.EventContext): The context of the event.
|
||||
event: The event to be checked.
|
||||
context: The context of the event.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[bool]: True if the event should be allowed, False if not.
|
||||
True if the event should be allowed, False if not.
|
||||
"""
|
||||
if self.third_party_rules is None:
|
||||
return True
|
||||
|
||||
prev_state_ids = yield context.get_prev_state_ids()
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
|
||||
# Retrieve the state events from the database.
|
||||
state_events = {}
|
||||
for key, event_id in prev_state_ids.items():
|
||||
state_events[key] = yield self.store.get_event(event_id, allow_none=True)
|
||||
state_events[key] = await self.store.get_event(event_id, allow_none=True)
|
||||
|
||||
ret = yield self.third_party_rules.check_event_allowed(event, state_events)
|
||||
ret = await self.third_party_rules.check_event_allowed(event, state_events)
|
||||
return ret
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_create_room(self, requester, config, is_requester_admin):
|
||||
async def on_create_room(
|
||||
self, requester: Requester, config: dict, is_requester_admin: bool
|
||||
) -> bool:
|
||||
"""Intercept requests to create room to allow, deny or update the
|
||||
request config.
|
||||
|
||||
Args:
|
||||
requester (Requester)
|
||||
config (dict): The creation config from the client.
|
||||
is_requester_admin (bool): If the requester is an admin
|
||||
requester
|
||||
config: The creation config from the client.
|
||||
is_requester_admin: If the requester is an admin
|
||||
|
||||
Returns:
|
||||
defer.Deferred[bool]: Whether room creation is allowed or denied.
|
||||
Whether room creation is allowed or denied.
|
||||
"""
|
||||
|
||||
if self.third_party_rules is None:
|
||||
return True
|
||||
|
||||
ret = yield self.third_party_rules.on_create_room(
|
||||
ret = await self.third_party_rules.on_create_room(
|
||||
requester, config, is_requester_admin
|
||||
)
|
||||
return ret
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_threepid_can_be_invited(self, medium, address, room_id):
|
||||
async def check_threepid_can_be_invited(
|
||||
self, medium: str, address: str, room_id: str
|
||||
) -> bool:
|
||||
"""Check if a provided 3PID can be invited in the given room.
|
||||
|
||||
Args:
|
||||
medium (str): The 3PID's medium.
|
||||
address (str): The 3PID's address.
|
||||
room_id (str): The room we want to invite the threepid to.
|
||||
medium: The 3PID's medium.
|
||||
address: The 3PID's address.
|
||||
room_id: The room we want to invite the threepid to.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[bool], True if the 3PID can be invited, False if not.
|
||||
True if the 3PID can be invited, False if not.
|
||||
"""
|
||||
|
||||
if self.third_party_rules is None:
|
||||
return True
|
||||
|
||||
state_ids = yield self.store.get_filtered_current_state_ids(room_id)
|
||||
room_state_events = yield self.store.get_events(state_ids.values())
|
||||
state_ids = await self.store.get_filtered_current_state_ids(room_id)
|
||||
room_state_events = await self.store.get_events(state_ids.values())
|
||||
|
||||
state_events = {}
|
||||
for key, event_id in state_ids.items():
|
||||
state_events[key] = room_state_events[event_id]
|
||||
|
||||
ret = yield self.third_party_rules.check_threepid_can_be_invited(
|
||||
ret = await self.third_party_rules.check_threepid_can_be_invited(
|
||||
medium, address, state_events
|
||||
)
|
||||
return ret
|
||||
|
|
|
@ -18,8 +18,6 @@ from typing import Any, Mapping, Union
|
|||
|
||||
from frozendict import frozendict
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, RelationTypes
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
|
@ -337,8 +335,9 @@ class EventClientSerializer(object):
|
|||
hs.config.experimental_msc1849_support_enabled
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs):
|
||||
async def serialize_event(
|
||||
self, event, time_now, bundle_aggregations=True, **kwargs
|
||||
):
|
||||
"""Serializes a single event.
|
||||
|
||||
Args:
|
||||
|
@ -348,7 +347,7 @@ class EventClientSerializer(object):
|
|||
**kwargs: Arguments to pass to `serialize_event`
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: The serialized event
|
||||
dict: The serialized event
|
||||
"""
|
||||
# To handle the case of presence events and the like
|
||||
if not isinstance(event, EventBase):
|
||||
|
@ -363,8 +362,8 @@ class EventClientSerializer(object):
|
|||
if not event.internal_metadata.is_redacted() and (
|
||||
self.experimental_msc1849_support_enabled and bundle_aggregations
|
||||
):
|
||||
annotations = yield self.store.get_aggregation_groups_for_event(event_id)
|
||||
references = yield self.store.get_relations_for_event(
|
||||
annotations = await self.store.get_aggregation_groups_for_event(event_id)
|
||||
references = await self.store.get_relations_for_event(
|
||||
event_id, RelationTypes.REFERENCE, direction="f"
|
||||
)
|
||||
|
||||
|
@ -378,7 +377,7 @@ class EventClientSerializer(object):
|
|||
|
||||
edit = None
|
||||
if event.type == EventTypes.Message:
|
||||
edit = yield self.store.get_applicable_edit(event_id)
|
||||
edit = await self.store.get_applicable_edit(event_id)
|
||||
|
||||
if edit:
|
||||
# If there is an edit replace the content, preserving existing
|
||||
|
|
|
@ -135,7 +135,7 @@ class FederationClient(FederationBase):
|
|||
and try the request anyway.
|
||||
|
||||
Returns:
|
||||
a Deferred which will eventually yield a JSON object from the
|
||||
a Awaitable which will eventually yield a JSON object from the
|
||||
response
|
||||
"""
|
||||
sent_queries_counter.labels(query_type).inc()
|
||||
|
@ -157,7 +157,7 @@ class FederationClient(FederationBase):
|
|||
content (dict): The query content.
|
||||
|
||||
Returns:
|
||||
a Deferred which will eventually yield a JSON object from the
|
||||
an Awaitable which will eventually yield a JSON object from the
|
||||
response
|
||||
"""
|
||||
sent_queries_counter.labels("client_device_keys").inc()
|
||||
|
@ -180,7 +180,7 @@ class FederationClient(FederationBase):
|
|||
content (dict): The query content.
|
||||
|
||||
Returns:
|
||||
a Deferred which will eventually yield a JSON object from the
|
||||
an Awaitable which will eventually yield a JSON object from the
|
||||
response
|
||||
"""
|
||||
sent_queries_counter.labels("client_one_time_keys").inc()
|
||||
|
@ -900,7 +900,7 @@ class FederationClient(FederationBase):
|
|||
party instance
|
||||
|
||||
Returns:
|
||||
Deferred[Dict[str, Any]]: The response from the remote server, or None if
|
||||
Awaitable[Dict[str, Any]]: The response from the remote server, or None if
|
||||
`remote_server` is the same as the local server_name
|
||||
|
||||
Raises:
|
||||
|
|
|
@ -288,8 +288,7 @@ class FederationSender(object):
|
|||
for destination in destinations:
|
||||
self._get_per_destination_queue(destination).send_pdu(pdu, order)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_read_receipt(self, receipt: ReadReceipt):
|
||||
async def send_read_receipt(self, receipt: ReadReceipt) -> None:
|
||||
"""Send a RR to any other servers in the room
|
||||
|
||||
Args:
|
||||
|
@ -330,9 +329,7 @@ class FederationSender(object):
|
|||
room_id = receipt.room_id
|
||||
|
||||
# Work out which remote servers should be poked and poke them.
|
||||
domains = yield defer.ensureDeferred(
|
||||
self.state.get_current_hosts_in_room(room_id)
|
||||
)
|
||||
domains = await self.state.get_current_hosts_in_room(room_id)
|
||||
domains = [
|
||||
d
|
||||
for d in domains
|
||||
|
@ -387,8 +384,7 @@ class FederationSender(object):
|
|||
queue.flush_read_receipts_for_room(room_id)
|
||||
|
||||
@preserve_fn # the caller should not yield on this
|
||||
@defer.inlineCallbacks
|
||||
def send_presence(self, states: List[UserPresenceState]):
|
||||
async def send_presence(self, states: List[UserPresenceState]):
|
||||
"""Send the new presence states to the appropriate destinations.
|
||||
|
||||
This actually queues up the presence states ready for sending and
|
||||
|
@ -423,7 +419,7 @@ class FederationSender(object):
|
|||
if not states_map:
|
||||
break
|
||||
|
||||
yield self._process_presence_inner(list(states_map.values()))
|
||||
await self._process_presence_inner(list(states_map.values()))
|
||||
except Exception:
|
||||
logger.exception("Error sending presence states to servers")
|
||||
finally:
|
||||
|
@ -450,14 +446,11 @@ class FederationSender(object):
|
|||
self._get_per_destination_queue(destination).send_presence(states)
|
||||
|
||||
@measure_func("txnqueue._process_presence")
|
||||
@defer.inlineCallbacks
|
||||
def _process_presence_inner(self, states: List[UserPresenceState]):
|
||||
async def _process_presence_inner(self, states: List[UserPresenceState]):
|
||||
"""Given a list of states populate self.pending_presence_by_dest and
|
||||
poke to send a new transaction to each destination
|
||||
"""
|
||||
hosts_and_states = yield defer.ensureDeferred(
|
||||
get_interested_remotes(self.store, states, self.state)
|
||||
)
|
||||
hosts_and_states = await get_interested_remotes(self.store, states, self.state)
|
||||
|
||||
for destinations, states in hosts_and_states:
|
||||
for destination in destinations:
|
||||
|
|
|
@ -18,8 +18,6 @@ import logging
|
|||
import urllib
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||
from synapse.api.urls import (
|
||||
|
@ -51,7 +49,7 @@ class TransportLayerClient(object):
|
|||
event_id (str): The event we want the context at.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a dict received from the remote homeserver.
|
||||
Awaitable: Results in a dict received from the remote homeserver.
|
||||
"""
|
||||
logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id)
|
||||
|
||||
|
@ -75,7 +73,7 @@ class TransportLayerClient(object):
|
|||
giving up. None indicates no timeout.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a dict received from the remote homeserver.
|
||||
Awaitable: Results in a dict received from the remote homeserver.
|
||||
"""
|
||||
logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id)
|
||||
|
||||
|
@ -96,7 +94,7 @@ class TransportLayerClient(object):
|
|||
limit (int)
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a dict received from the remote homeserver.
|
||||
Awaitable: Results in a dict received from the remote homeserver.
|
||||
"""
|
||||
logger.debug(
|
||||
"backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s",
|
||||
|
@ -118,16 +116,15 @@ class TransportLayerClient(object):
|
|||
destination, path=path, args=args, try_trailing_slash_on_400=True
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_transaction(self, transaction, json_data_callback=None):
|
||||
async def send_transaction(self, transaction, json_data_callback=None):
|
||||
""" Sends the given Transaction to its destination
|
||||
|
||||
Args:
|
||||
transaction (Transaction)
|
||||
|
||||
Returns:
|
||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||
Succeeds when we get a 2xx HTTP response. The result
|
||||
will be the decoded JSON body.
|
||||
|
||||
Fails with ``HTTPRequestException`` if we get an HTTP response
|
||||
|
@ -154,7 +151,7 @@ class TransportLayerClient(object):
|
|||
|
||||
path = _create_v1_path("/send/%s", transaction.transaction_id)
|
||||
|
||||
response = yield self.client.put_json(
|
||||
response = await self.client.put_json(
|
||||
transaction.destination,
|
||||
path=path,
|
||||
data=json_data,
|
||||
|
@ -166,14 +163,13 @@ class TransportLayerClient(object):
|
|||
|
||||
return response
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def make_query(
|
||||
async def make_query(
|
||||
self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
|
||||
):
|
||||
path = _create_v1_path("/query/%s", query_type)
|
||||
|
||||
content = yield self.client.get_json(
|
||||
content = await self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args=args,
|
||||
|
@ -184,9 +180,10 @@ class TransportLayerClient(object):
|
|||
|
||||
return content
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def make_membership_event(self, destination, room_id, user_id, membership, params):
|
||||
async def make_membership_event(
|
||||
self, destination, room_id, user_id, membership, params
|
||||
):
|
||||
"""Asks a remote server to build and sign us a membership event
|
||||
|
||||
Note that this does not append any events to any graphs.
|
||||
|
@ -200,7 +197,7 @@ class TransportLayerClient(object):
|
|||
request.
|
||||
|
||||
Returns:
|
||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||
Succeeds when we get a 2xx HTTP response. The result
|
||||
will be the decoded JSON body (ie, the new event).
|
||||
|
||||
Fails with ``HTTPRequestException`` if we get an HTTP response
|
||||
|
@ -231,7 +228,7 @@ class TransportLayerClient(object):
|
|||
ignore_backoff = True
|
||||
retry_on_dns_fail = True
|
||||
|
||||
content = yield self.client.get_json(
|
||||
content = await self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args=params,
|
||||
|
@ -242,34 +239,31 @@ class TransportLayerClient(object):
|
|||
|
||||
return content
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_join_v1(self, destination, room_id, event_id, content):
|
||||
async def send_join_v1(self, destination, room_id, event_id, content):
|
||||
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
|
||||
|
||||
response = yield self.client.put_json(
|
||||
response = await self.client.put_json(
|
||||
destination=destination, path=path, data=content
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_join_v2(self, destination, room_id, event_id, content):
|
||||
async def send_join_v2(self, destination, room_id, event_id, content):
|
||||
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
|
||||
|
||||
response = yield self.client.put_json(
|
||||
response = await self.client.put_json(
|
||||
destination=destination, path=path, data=content
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_leave_v1(self, destination, room_id, event_id, content):
|
||||
async def send_leave_v1(self, destination, room_id, event_id, content):
|
||||
path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
|
||||
|
||||
response = yield self.client.put_json(
|
||||
response = await self.client.put_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
|
@ -282,12 +276,11 @@ class TransportLayerClient(object):
|
|||
|
||||
return response
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_leave_v2(self, destination, room_id, event_id, content):
|
||||
async def send_leave_v2(self, destination, room_id, event_id, content):
|
||||
path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
|
||||
|
||||
response = yield self.client.put_json(
|
||||
response = await self.client.put_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
|
@ -300,31 +293,28 @@ class TransportLayerClient(object):
|
|||
|
||||
return response
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_invite_v1(self, destination, room_id, event_id, content):
|
||||
async def send_invite_v1(self, destination, room_id, event_id, content):
|
||||
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
|
||||
|
||||
response = yield self.client.put_json(
|
||||
response = await self.client.put_json(
|
||||
destination=destination, path=path, data=content, ignore_backoff=True
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_invite_v2(self, destination, room_id, event_id, content):
|
||||
async def send_invite_v2(self, destination, room_id, event_id, content):
|
||||
path = _create_v2_path("/invite/%s/%s", room_id, event_id)
|
||||
|
||||
response = yield self.client.put_json(
|
||||
response = await self.client.put_json(
|
||||
destination=destination, path=path, data=content, ignore_backoff=True
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_public_rooms(
|
||||
async def get_public_rooms(
|
||||
self,
|
||||
remote_server: str,
|
||||
limit: Optional[int] = None,
|
||||
|
@ -355,7 +345,7 @@ class TransportLayerClient(object):
|
|||
data["filter"] = search_filter
|
||||
|
||||
try:
|
||||
response = yield self.client.post_json(
|
||||
response = await self.client.post_json(
|
||||
destination=remote_server, path=path, data=data, ignore_backoff=True
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
|
@ -381,7 +371,7 @@ class TransportLayerClient(object):
|
|||
args["since"] = [since_token]
|
||||
|
||||
try:
|
||||
response = yield self.client.get_json(
|
||||
response = await self.client.get_json(
|
||||
destination=remote_server, path=path, args=args, ignore_backoff=True
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
|
@ -396,29 +386,26 @@ class TransportLayerClient(object):
|
|||
|
||||
return response
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def exchange_third_party_invite(self, destination, room_id, event_dict):
|
||||
async def exchange_third_party_invite(self, destination, room_id, event_dict):
|
||||
path = _create_v1_path("/exchange_third_party_invite/%s", room_id)
|
||||
|
||||
response = yield self.client.put_json(
|
||||
response = await self.client.put_json(
|
||||
destination=destination, path=path, data=event_dict
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_event_auth(self, destination, room_id, event_id):
|
||||
async def get_event_auth(self, destination, room_id, event_id):
|
||||
path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
|
||||
|
||||
content = yield self.client.get_json(destination=destination, path=path)
|
||||
content = await self.client.get_json(destination=destination, path=path)
|
||||
|
||||
return content
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def query_client_keys(self, destination, query_content, timeout):
|
||||
async def query_client_keys(self, destination, query_content, timeout):
|
||||
"""Query the device keys for a list of user ids hosted on a remote
|
||||
server.
|
||||
|
||||
|
@ -453,14 +440,13 @@ class TransportLayerClient(object):
|
|||
"""
|
||||
path = _create_v1_path("/user/keys/query")
|
||||
|
||||
content = yield self.client.post_json(
|
||||
content = await self.client.post_json(
|
||||
destination=destination, path=path, data=query_content, timeout=timeout
|
||||
)
|
||||
return content
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def query_user_devices(self, destination, user_id, timeout):
|
||||
async def query_user_devices(self, destination, user_id, timeout):
|
||||
"""Query the devices for a user id hosted on a remote server.
|
||||
|
||||
Response:
|
||||
|
@ -493,14 +479,13 @@ class TransportLayerClient(object):
|
|||
"""
|
||||
path = _create_v1_path("/user/devices/%s", user_id)
|
||||
|
||||
content = yield self.client.get_json(
|
||||
content = await self.client.get_json(
|
||||
destination=destination, path=path, timeout=timeout
|
||||
)
|
||||
return content
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def claim_client_keys(self, destination, query_content, timeout):
|
||||
async def claim_client_keys(self, destination, query_content, timeout):
|
||||
"""Claim one-time keys for a list of devices hosted on a remote server.
|
||||
|
||||
Request:
|
||||
|
@ -532,14 +517,13 @@ class TransportLayerClient(object):
|
|||
|
||||
path = _create_v1_path("/user/keys/claim")
|
||||
|
||||
content = yield self.client.post_json(
|
||||
content = await self.client.post_json(
|
||||
destination=destination, path=path, data=query_content, timeout=timeout
|
||||
)
|
||||
return content
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_missing_events(
|
||||
async def get_missing_events(
|
||||
self,
|
||||
destination,
|
||||
room_id,
|
||||
|
@ -551,7 +535,7 @@ class TransportLayerClient(object):
|
|||
):
|
||||
path = _create_v1_path("/get_missing_events/%s", room_id)
|
||||
|
||||
content = yield self.client.post_json(
|
||||
content = await self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data={
|
||||
|
|
|
@ -41,8 +41,6 @@ from typing import Tuple
|
|||
|
||||
from signedjson.sign import sign_json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.types import get_domain_from_id
|
||||
|
@ -72,8 +70,9 @@ class GroupAttestationSigning(object):
|
|||
self.server_name = hs.hostname
|
||||
self.signing_key = hs.signing_key
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def verify_attestation(self, attestation, group_id, user_id, server_name=None):
|
||||
async def verify_attestation(
|
||||
self, attestation, group_id, user_id, server_name=None
|
||||
):
|
||||
"""Verifies that the given attestation matches the given parameters.
|
||||
|
||||
An optional server_name can be supplied to explicitly set which server's
|
||||
|
@ -102,7 +101,7 @@ class GroupAttestationSigning(object):
|
|||
if valid_until_ms < now:
|
||||
raise SynapseError(400, "Attestation expired")
|
||||
|
||||
yield self.keyring.verify_json_for_server(
|
||||
await self.keyring.verify_json_for_server(
|
||||
server_name, attestation, now, "Group attestation"
|
||||
)
|
||||
|
||||
|
@ -142,8 +141,7 @@ class GroupAttestionRenewer(object):
|
|||
self._start_renew_attestations, 30 * 60 * 1000
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_renew_attestation(self, group_id, user_id, content):
|
||||
async def on_renew_attestation(self, group_id, user_id, content):
|
||||
"""When a remote updates an attestation
|
||||
"""
|
||||
attestation = content["attestation"]
|
||||
|
@ -151,11 +149,11 @@ class GroupAttestionRenewer(object):
|
|||
if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
|
||||
raise SynapseError(400, "Neither user not group are on this server")
|
||||
|
||||
yield self.attestations.verify_attestation(
|
||||
await self.attestations.verify_attestation(
|
||||
attestation, user_id=user_id, group_id=group_id
|
||||
)
|
||||
|
||||
yield self.store.update_remote_attestion(group_id, user_id, attestation)
|
||||
await self.store.update_remote_attestion(group_id, user_id, attestation)
|
||||
|
||||
return {}
|
||||
|
||||
|
@ -172,8 +170,7 @@ class GroupAttestionRenewer(object):
|
|||
now + UPDATE_ATTESTATION_TIME_MS
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _renew_attestation(group_user: Tuple[str, str]):
|
||||
async def _renew_attestation(group_user: Tuple[str, str]):
|
||||
group_id, user_id = group_user
|
||||
try:
|
||||
if not self.is_mine_id(group_id):
|
||||
|
@ -186,16 +183,16 @@ class GroupAttestionRenewer(object):
|
|||
user_id,
|
||||
group_id,
|
||||
)
|
||||
yield self.store.remove_attestation_renewal(group_id, user_id)
|
||||
await self.store.remove_attestation_renewal(group_id, user_id)
|
||||
return
|
||||
|
||||
attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
|
||||
yield self.transport_client.renew_group_attestation(
|
||||
await self.transport_client.renew_group_attestation(
|
||||
destination, group_id, user_id, content={"attestation": attestation}
|
||||
)
|
||||
|
||||
yield self.store.update_attestation_renewal(
|
||||
await self.store.update_attestation_renewal(
|
||||
group_id, user_id, attestation
|
||||
)
|
||||
except (RequestSendFailed, HttpResponseException) as e:
|
||||
|
|
|
@ -27,7 +27,6 @@ from synapse.metrics import (
|
|||
event_processing_loop_room_count,
|
||||
)
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util import log_failure
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -100,10 +99,11 @@ class ApplicationServicesHandler(object):
|
|||
|
||||
if not self.started_scheduler:
|
||||
|
||||
def start_scheduler():
|
||||
return self.scheduler.start().addErrback(
|
||||
log_failure, "Application Services Failure"
|
||||
)
|
||||
async def start_scheduler():
|
||||
try:
|
||||
return self.scheduler.start()
|
||||
except Exception:
|
||||
logger.error("Application Services Failure")
|
||||
|
||||
run_as_background_process("as_scheduler", start_scheduler)
|
||||
self.started_scheduler = True
|
||||
|
|
|
@ -2470,7 +2470,7 @@ class FederationHandler(BaseHandler):
|
|||
}
|
||||
|
||||
current_state_ids = await context.get_current_state_ids()
|
||||
current_state_ids = dict(current_state_ids)
|
||||
current_state_ids = dict(current_state_ids) # type: ignore
|
||||
|
||||
current_state_ids.update(state_updates)
|
||||
|
||||
|
|
|
@ -23,39 +23,32 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def _create_rerouter(func_name):
|
||||
"""Returns a function that looks at the group id and calls the function
|
||||
"""Returns an async function that looks at the group id and calls the function
|
||||
on federation or the local group server if the group is local
|
||||
"""
|
||||
|
||||
def f(self, group_id, *args, **kwargs):
|
||||
async def f(self, group_id, *args, **kwargs):
|
||||
if self.is_mine_id(group_id):
|
||||
return getattr(self.groups_server_handler, func_name)(
|
||||
return await getattr(self.groups_server_handler, func_name)(
|
||||
group_id, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
destination = get_domain_from_id(group_id)
|
||||
d = getattr(self.transport_client, func_name)(
|
||||
|
||||
try:
|
||||
return await getattr(self.transport_client, func_name)(
|
||||
destination, group_id, *args, **kwargs
|
||||
)
|
||||
|
||||
except HttpResponseException as e:
|
||||
# Capture errors returned by the remote homeserver and
|
||||
# re-throw specific errors as SynapseErrors. This is so
|
||||
# when the remote end responds with things like 403 Not
|
||||
# In Group, we can communicate that to the client instead
|
||||
# of a 500.
|
||||
def http_response_errback(failure):
|
||||
failure.trap(HttpResponseException)
|
||||
e = failure.value
|
||||
raise e.to_synapse_error()
|
||||
|
||||
def request_failed_errback(failure):
|
||||
failure.trap(RequestSendFailed)
|
||||
except RequestSendFailed:
|
||||
raise SynapseError(502, "Failed to contact group server")
|
||||
|
||||
d.addErrback(http_response_errback)
|
||||
d.addErrback(request_failed_errback)
|
||||
return d
|
||||
|
||||
return f
|
||||
|
||||
|
||||
|
|
|
@ -502,26 +502,39 @@ class RoomMemberHandler(object):
|
|||
user_id=target.to_string(), room_id=room_id
|
||||
) # type: Optional[RoomsForUser]
|
||||
if not invite:
|
||||
logger.info(
|
||||
"%s sent a leave request to %s, but that is not an active room "
|
||||
"on this server, and there is no pending invite",
|
||||
target,
|
||||
room_id,
|
||||
)
|
||||
|
||||
raise SynapseError(404, "Not a known room")
|
||||
|
||||
logger.info(
|
||||
"%s rejects invite to %s from %s", target, room_id, invite.sender
|
||||
)
|
||||
|
||||
if self.hs.is_mine_id(invite.sender):
|
||||
# the inviter was on our server, but has now left. Carry on
|
||||
# with the normal rejection codepath.
|
||||
#
|
||||
# This is a bit of a hack, because the room might still be
|
||||
# active on other servers.
|
||||
pass
|
||||
else:
|
||||
if not self.hs.is_mine_id(invite.sender):
|
||||
# send the rejection to the inviter's HS (with fallback to
|
||||
# local event)
|
||||
return await self.remote_reject_invite(
|
||||
invite.event_id, txn_id, requester, content,
|
||||
)
|
||||
|
||||
# the inviter was on our server, but has now left. Carry on
|
||||
# with the normal rejection codepath, which will also send the
|
||||
# rejection out to any other servers we believe are still in the room.
|
||||
|
||||
# thanks to overzealous cleaning up of event_forward_extremities in
|
||||
# `delete_old_current_state_events`, it's possible to end up with no
|
||||
# forward extremities here. If that happens, let's just hang the
|
||||
# rejection off the invite event.
|
||||
#
|
||||
# see: https://github.com/matrix-org/synapse/issues/7139
|
||||
if len(latest_event_ids) == 0:
|
||||
latest_event_ids = [invite.event_id]
|
||||
|
||||
return await self._local_membership_update(
|
||||
requester=requester,
|
||||
target=target,
|
||||
|
@ -985,7 +998,11 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||
if len(remote_room_hosts) == 0:
|
||||
raise SynapseError(404, "No known servers")
|
||||
|
||||
if self.hs.config.limit_remote_rooms.enabled:
|
||||
check_complexity = self.hs.config.limit_remote_rooms.enabled
|
||||
if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join:
|
||||
check_complexity = not await self.hs.auth.is_server_admin(user)
|
||||
|
||||
if check_complexity:
|
||||
# Fetch the room complexity
|
||||
too_complex = await self._is_remote_room_too_complex(
|
||||
room_id, remote_room_hosts
|
||||
|
@ -1008,7 +1025,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||
|
||||
# Check the room we just joined wasn't too large, if we didn't fetch the
|
||||
# complexity of it before.
|
||||
if self.hs.config.limit_remote_rooms.enabled:
|
||||
if check_complexity:
|
||||
if too_complex is False:
|
||||
# We checked, and we're under the limit.
|
||||
return event_id, stream_id
|
||||
|
|
|
@ -96,6 +96,9 @@ class SamlHandler:
|
|||
relay_state=client_redirect_url
|
||||
)
|
||||
|
||||
# Since SAML sessions timeout it is useful to log when they were created.
|
||||
logger.info("Initiating a new SAML session: %s" % (reqid,))
|
||||
|
||||
now = self._clock.time_msec()
|
||||
self._outstanding_requests_dict[reqid] = Saml2SessionData(
|
||||
creation_time=now, ui_auth_session_id=ui_auth_session_id,
|
||||
|
|
|
@ -103,6 +103,7 @@ class JoinedSyncResult:
|
|||
account_data = attr.ib(type=List[JsonDict])
|
||||
unread_notifications = attr.ib(type=JsonDict)
|
||||
summary = attr.ib(type=Optional[JsonDict])
|
||||
unread_count = attr.ib(type=int)
|
||||
|
||||
def __nonzero__(self) -> bool:
|
||||
"""Make the result appear empty if there are no updates. This is used
|
||||
|
@ -1886,6 +1887,10 @@ class SyncHandler(object):
|
|||
|
||||
if room_builder.rtype == "joined":
|
||||
unread_notifications = {} # type: Dict[str, str]
|
||||
|
||||
unread_count = await self.store.get_unread_message_count_for_user(
|
||||
room_id, sync_config.user.to_string(),
|
||||
)
|
||||
room_sync = JoinedSyncResult(
|
||||
room_id=room_id,
|
||||
timeline=batch,
|
||||
|
@ -1894,6 +1899,7 @@ class SyncHandler(object):
|
|||
account_data=account_data_events,
|
||||
unread_notifications=unread_notifications,
|
||||
summary=summary,
|
||||
unread_count=unread_count,
|
||||
)
|
||||
|
||||
if room_sync or always_include:
|
||||
|
|
|
@ -395,7 +395,9 @@ class SimpleHttpClient(object):
|
|||
if 200 <= response.code < 300:
|
||||
return json.loads(body.decode("utf-8"))
|
||||
else:
|
||||
raise HttpResponseException(response.code, response.phrase, body)
|
||||
raise HttpResponseException(
|
||||
response.code, response.phrase.decode("ascii", errors="replace"), body
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def post_json_get_json(self, uri, post_json, headers=None):
|
||||
|
@ -436,7 +438,9 @@ class SimpleHttpClient(object):
|
|||
if 200 <= response.code < 300:
|
||||
return json.loads(body.decode("utf-8"))
|
||||
else:
|
||||
raise HttpResponseException(response.code, response.phrase, body)
|
||||
raise HttpResponseException(
|
||||
response.code, response.phrase.decode("ascii", errors="replace"), body
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_json(self, uri, args={}, headers=None):
|
||||
|
@ -509,7 +513,9 @@ class SimpleHttpClient(object):
|
|||
if 200 <= response.code < 300:
|
||||
return json.loads(body.decode("utf-8"))
|
||||
else:
|
||||
raise HttpResponseException(response.code, response.phrase, body)
|
||||
raise HttpResponseException(
|
||||
response.code, response.phrase.decode("ascii", errors="replace"), body
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_raw(self, uri, args={}, headers=None):
|
||||
|
@ -544,7 +550,9 @@ class SimpleHttpClient(object):
|
|||
if 200 <= response.code < 300:
|
||||
return body
|
||||
else:
|
||||
raise HttpResponseException(response.code, response.phrase, body)
|
||||
raise HttpResponseException(
|
||||
response.code, response.phrase.decode("ascii", errors="replace"), body
|
||||
)
|
||||
|
||||
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
|
||||
# The two should be factored out.
|
||||
|
|
|
@ -121,8 +121,7 @@ class MatrixFederationRequest(object):
|
|||
return self.json
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_json_response(reactor, timeout_sec, request, response):
|
||||
async def _handle_json_response(reactor, timeout_sec, request, response):
|
||||
"""
|
||||
Reads the JSON body of a response, with a timeout
|
||||
|
||||
|
@ -141,7 +140,7 @@ def _handle_json_response(reactor, timeout_sec, request, response):
|
|||
d = treq.json_content(response)
|
||||
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
|
||||
|
||||
body = yield make_deferred_yieldable(d)
|
||||
body = await make_deferred_yieldable(d)
|
||||
except TimeoutError as e:
|
||||
logger.warning(
|
||||
"{%s} [%s] Timed out reading response", request.txn_id, request.destination,
|
||||
|
@ -224,8 +223,7 @@ class MatrixFederationHttpClient(object):
|
|||
|
||||
self._cooperator = Cooperator(scheduler=schedule)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _send_request_with_optional_trailing_slash(
|
||||
async def _send_request_with_optional_trailing_slash(
|
||||
self, request, try_trailing_slash_on_400=False, **send_request_args
|
||||
):
|
||||
"""Wrapper for _send_request which can optionally retry the request
|
||||
|
@ -246,10 +244,10 @@ class MatrixFederationHttpClient(object):
|
|||
(except 429).
|
||||
|
||||
Returns:
|
||||
Deferred[Dict]: Parsed JSON response body.
|
||||
Dict: Parsed JSON response body.
|
||||
"""
|
||||
try:
|
||||
response = yield self._send_request(request, **send_request_args)
|
||||
response = await self._send_request(request, **send_request_args)
|
||||
except HttpResponseException as e:
|
||||
# Received an HTTP error > 300. Check if it meets the requirements
|
||||
# to retry with a trailing slash
|
||||
|
@ -265,12 +263,11 @@ class MatrixFederationHttpClient(object):
|
|||
logger.info("Retrying request with trailing slash")
|
||||
request.path += "/"
|
||||
|
||||
response = yield self._send_request(request, **send_request_args)
|
||||
response = await self._send_request(request, **send_request_args)
|
||||
|
||||
return response
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _send_request(
|
||||
async def _send_request(
|
||||
self,
|
||||
request,
|
||||
retry_on_dns_fail=True,
|
||||
|
@ -311,7 +308,7 @@ class MatrixFederationHttpClient(object):
|
|||
backoff_on_404 (bool): Back off if we get a 404
|
||||
|
||||
Returns:
|
||||
Deferred[twisted.web.client.Response]: resolves with the HTTP
|
||||
twisted.web.client.Response: resolves with the HTTP
|
||||
response object on success.
|
||||
|
||||
Raises:
|
||||
|
@ -335,7 +332,7 @@ class MatrixFederationHttpClient(object):
|
|||
):
|
||||
raise FederationDeniedError(request.destination)
|
||||
|
||||
limiter = yield synapse.util.retryutils.get_retry_limiter(
|
||||
limiter = await synapse.util.retryutils.get_retry_limiter(
|
||||
request.destination,
|
||||
self.clock,
|
||||
self._store,
|
||||
|
@ -433,7 +430,7 @@ class MatrixFederationHttpClient(object):
|
|||
reactor=self.reactor,
|
||||
)
|
||||
|
||||
response = yield request_deferred
|
||||
response = await request_deferred
|
||||
except TimeoutError as e:
|
||||
raise RequestSendFailed(e, can_retry=True) from e
|
||||
except DNSLookupError as e:
|
||||
|
@ -447,6 +444,7 @@ class MatrixFederationHttpClient(object):
|
|||
).inc()
|
||||
|
||||
set_tag(tags.HTTP_STATUS_CODE, response.code)
|
||||
response_phrase = response.phrase.decode("ascii", errors="replace")
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
logger.debug(
|
||||
|
@ -454,7 +452,7 @@ class MatrixFederationHttpClient(object):
|
|||
request.txn_id,
|
||||
request.destination,
|
||||
response.code,
|
||||
response.phrase.decode("ascii", errors="replace"),
|
||||
response_phrase,
|
||||
)
|
||||
pass
|
||||
else:
|
||||
|
@ -463,7 +461,7 @@ class MatrixFederationHttpClient(object):
|
|||
request.txn_id,
|
||||
request.destination,
|
||||
response.code,
|
||||
response.phrase.decode("ascii", errors="replace"),
|
||||
response_phrase,
|
||||
)
|
||||
# :'(
|
||||
# Update transactions table?
|
||||
|
@ -473,7 +471,7 @@ class MatrixFederationHttpClient(object):
|
|||
)
|
||||
|
||||
try:
|
||||
body = yield make_deferred_yieldable(d)
|
||||
body = await make_deferred_yieldable(d)
|
||||
except Exception as e:
|
||||
# Eh, we're already going to raise an exception so lets
|
||||
# ignore if this fails.
|
||||
|
@ -487,7 +485,7 @@ class MatrixFederationHttpClient(object):
|
|||
)
|
||||
body = None
|
||||
|
||||
e = HttpResponseException(response.code, response.phrase, body)
|
||||
e = HttpResponseException(response.code, response_phrase, body)
|
||||
|
||||
# Retry if the error is a 429 (Too Many Requests),
|
||||
# otherwise just raise a standard HttpResponseException
|
||||
|
@ -527,7 +525,7 @@ class MatrixFederationHttpClient(object):
|
|||
delay,
|
||||
)
|
||||
|
||||
yield self.clock.sleep(delay)
|
||||
await self.clock.sleep(delay)
|
||||
retries_left -= 1
|
||||
else:
|
||||
raise
|
||||
|
@ -590,8 +588,7 @@ class MatrixFederationHttpClient(object):
|
|||
)
|
||||
return auth_headers
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def put_json(
|
||||
async def put_json(
|
||||
self,
|
||||
destination,
|
||||
path,
|
||||
|
@ -635,7 +632,7 @@ class MatrixFederationHttpClient(object):
|
|||
enabled.
|
||||
|
||||
Returns:
|
||||
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
|
||||
dict|list: Succeeds when we get a 2xx HTTP response. The
|
||||
result will be the decoded JSON body.
|
||||
|
||||
Raises:
|
||||
|
@ -657,7 +654,7 @@ class MatrixFederationHttpClient(object):
|
|||
json=data,
|
||||
)
|
||||
|
||||
response = yield self._send_request_with_optional_trailing_slash(
|
||||
response = await self._send_request_with_optional_trailing_slash(
|
||||
request,
|
||||
try_trailing_slash_on_400,
|
||||
backoff_on_404=backoff_on_404,
|
||||
|
@ -666,14 +663,13 @@ class MatrixFederationHttpClient(object):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
body = yield _handle_json_response(
|
||||
body = await _handle_json_response(
|
||||
self.reactor, self.default_timeout, request, response
|
||||
)
|
||||
|
||||
return body
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def post_json(
|
||||
async def post_json(
|
||||
self,
|
||||
destination,
|
||||
path,
|
||||
|
@ -706,7 +702,7 @@ class MatrixFederationHttpClient(object):
|
|||
|
||||
args (dict): query params
|
||||
Returns:
|
||||
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
|
||||
dict|list: Succeeds when we get a 2xx HTTP response. The
|
||||
result will be the decoded JSON body.
|
||||
|
||||
Raises:
|
||||
|
@ -724,7 +720,7 @@ class MatrixFederationHttpClient(object):
|
|||
method="POST", destination=destination, path=path, query=args, json=data
|
||||
)
|
||||
|
||||
response = yield self._send_request(
|
||||
response = await self._send_request(
|
||||
request,
|
||||
long_retries=long_retries,
|
||||
timeout=timeout,
|
||||
|
@ -736,13 +732,12 @@ class MatrixFederationHttpClient(object):
|
|||
else:
|
||||
_sec_timeout = self.default_timeout
|
||||
|
||||
body = yield _handle_json_response(
|
||||
body = await _handle_json_response(
|
||||
self.reactor, _sec_timeout, request, response
|
||||
)
|
||||
return body
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_json(
|
||||
async def get_json(
|
||||
self,
|
||||
destination,
|
||||
path,
|
||||
|
@ -774,7 +769,7 @@ class MatrixFederationHttpClient(object):
|
|||
response we should try appending a trailing slash to the end of
|
||||
the request. Workaround for #3622 in Synapse <= v0.99.3.
|
||||
Returns:
|
||||
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
|
||||
dict|list: Succeeds when we get a 2xx HTTP response. The
|
||||
result will be the decoded JSON body.
|
||||
|
||||
Raises:
|
||||
|
@ -791,7 +786,7 @@ class MatrixFederationHttpClient(object):
|
|||
method="GET", destination=destination, path=path, query=args
|
||||
)
|
||||
|
||||
response = yield self._send_request_with_optional_trailing_slash(
|
||||
response = await self._send_request_with_optional_trailing_slash(
|
||||
request,
|
||||
try_trailing_slash_on_400,
|
||||
backoff_on_404=False,
|
||||
|
@ -800,14 +795,13 @@ class MatrixFederationHttpClient(object):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
body = yield _handle_json_response(
|
||||
body = await _handle_json_response(
|
||||
self.reactor, self.default_timeout, request, response
|
||||
)
|
||||
|
||||
return body
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_json(
|
||||
async def delete_json(
|
||||
self,
|
||||
destination,
|
||||
path,
|
||||
|
@ -835,7 +829,7 @@ class MatrixFederationHttpClient(object):
|
|||
|
||||
args (dict): query params
|
||||
Returns:
|
||||
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
|
||||
dict|list: Succeeds when we get a 2xx HTTP response. The
|
||||
result will be the decoded JSON body.
|
||||
|
||||
Raises:
|
||||
|
@ -852,20 +846,19 @@ class MatrixFederationHttpClient(object):
|
|||
method="DELETE", destination=destination, path=path, query=args
|
||||
)
|
||||
|
||||
response = yield self._send_request(
|
||||
response = await self._send_request(
|
||||
request,
|
||||
long_retries=long_retries,
|
||||
timeout=timeout,
|
||||
ignore_backoff=ignore_backoff,
|
||||
)
|
||||
|
||||
body = yield _handle_json_response(
|
||||
body = await _handle_json_response(
|
||||
self.reactor, self.default_timeout, request, response
|
||||
)
|
||||
return body
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_file(
|
||||
async def get_file(
|
||||
self,
|
||||
destination,
|
||||
path,
|
||||
|
@ -885,7 +878,7 @@ class MatrixFederationHttpClient(object):
|
|||
and try the request anyway.
|
||||
|
||||
Returns:
|
||||
Deferred[tuple[int, dict]]: Resolves with an (int,dict) tuple of
|
||||
tuple[int, dict]: Resolves with an (int,dict) tuple of
|
||||
the file length and a dict of the response headers.
|
||||
|
||||
Raises:
|
||||
|
@ -902,7 +895,7 @@ class MatrixFederationHttpClient(object):
|
|||
method="GET", destination=destination, path=path, query=args
|
||||
)
|
||||
|
||||
response = yield self._send_request(
|
||||
response = await self._send_request(
|
||||
request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff
|
||||
)
|
||||
|
||||
|
@ -911,7 +904,7 @@ class MatrixFederationHttpClient(object):
|
|||
try:
|
||||
d = _readBodyToFile(response, output_stream, max_size)
|
||||
d.addTimeout(self.default_timeout, self.reactor)
|
||||
length = yield make_deferred_yieldable(d)
|
||||
length = await make_deferred_yieldable(d)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"{%s} [%s] Error reading response: %s",
|
||||
|
|
|
@ -15,8 +15,6 @@
|
|||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
|
||||
|
@ -37,7 +35,6 @@ class ActionGenerator(object):
|
|||
# event stream, so we just run the rules for a client with no profile
|
||||
# tag (ie. we just need all the users).
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_push_actions_for_event(self, event, context):
|
||||
async def handle_push_actions_for_event(self, event, context):
|
||||
with Measure(self.clock, "action_for_event_by_user"):
|
||||
yield self.bulk_evaluator.action_for_event_by_user(event, context)
|
||||
await self.bulk_evaluator.action_for_event_by_user(event, context)
|
||||
|
|
|
@ -19,8 +19,6 @@ from collections import namedtuple
|
|||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.event_auth import get_user_power_level
|
||||
from synapse.state import POWER_KEY
|
||||
|
@ -70,8 +68,7 @@ class BulkPushRuleEvaluator(object):
|
|||
resizable=False,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_rules_for_event(self, event, context):
|
||||
async def _get_rules_for_event(self, event, context):
|
||||
"""This gets the rules for all users in the room at the time of the event,
|
||||
as well as the push rules for the invitee if the event is an invite.
|
||||
|
||||
|
@ -79,19 +76,19 @@ class BulkPushRuleEvaluator(object):
|
|||
dict of user_id -> push_rules
|
||||
"""
|
||||
room_id = event.room_id
|
||||
rules_for_room = yield self._get_rules_for_room(room_id)
|
||||
rules_for_room = await self._get_rules_for_room(room_id)
|
||||
|
||||
rules_by_user = yield rules_for_room.get_rules(event, context)
|
||||
rules_by_user = await rules_for_room.get_rules(event, context)
|
||||
|
||||
# if this event is an invite event, we may need to run rules for the user
|
||||
# who's been invited, otherwise they won't get told they've been invited
|
||||
if event.type == "m.room.member" and event.content["membership"] == "invite":
|
||||
invited = event.state_key
|
||||
if invited and self.hs.is_mine_id(invited):
|
||||
has_pusher = yield self.store.user_has_pusher(invited)
|
||||
has_pusher = await self.store.user_has_pusher(invited)
|
||||
if has_pusher:
|
||||
rules_by_user = dict(rules_by_user)
|
||||
rules_by_user[invited] = yield self.store.get_push_rules_for_user(
|
||||
rules_by_user[invited] = await self.store.get_push_rules_for_user(
|
||||
invited
|
||||
)
|
||||
|
||||
|
@ -114,20 +111,19 @@ class BulkPushRuleEvaluator(object):
|
|||
self.room_push_rule_cache_metrics,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_power_levels_and_sender_level(self, event, context):
|
||||
prev_state_ids = yield context.get_prev_state_ids()
|
||||
async def _get_power_levels_and_sender_level(self, event, context):
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
pl_event_id = prev_state_ids.get(POWER_KEY)
|
||||
if pl_event_id:
|
||||
# fastpath: if there's a power level event, that's all we need, and
|
||||
# not having a power level event is an extreme edge case
|
||||
pl_event = yield self.store.get_event(pl_event_id)
|
||||
pl_event = await self.store.get_event(pl_event_id)
|
||||
auth_events = {POWER_KEY: pl_event}
|
||||
else:
|
||||
auth_events_ids = yield self.auth.compute_auth_events(
|
||||
auth_events_ids = await self.auth.compute_auth_events(
|
||||
event, prev_state_ids, for_verification=False
|
||||
)
|
||||
auth_events = yield self.store.get_events(auth_events_ids)
|
||||
auth_events = await self.store.get_events(auth_events_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
|
||||
|
||||
sender_level = get_user_power_level(event.sender, auth_events)
|
||||
|
@ -136,23 +132,19 @@ class BulkPushRuleEvaluator(object):
|
|||
|
||||
return pl_event.content if pl_event else {}, sender_level
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def action_for_event_by_user(self, event, context):
|
||||
async def action_for_event_by_user(self, event, context) -> None:
|
||||
"""Given an event and context, evaluate the push rules and insert the
|
||||
results into the event_push_actions_staging table.
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
rules_by_user = yield self._get_rules_for_event(event, context)
|
||||
rules_by_user = await self._get_rules_for_event(event, context)
|
||||
actions_by_user = {}
|
||||
|
||||
room_members = yield self.store.get_joined_users_from_context(event, context)
|
||||
room_members = await self.store.get_joined_users_from_context(event, context)
|
||||
|
||||
(
|
||||
power_levels,
|
||||
sender_power_level,
|
||||
) = yield self._get_power_levels_and_sender_level(event, context)
|
||||
) = await self._get_power_levels_and_sender_level(event, context)
|
||||
|
||||
evaluator = PushRuleEvaluatorForEvent(
|
||||
event, len(room_members), sender_power_level, power_levels
|
||||
|
@ -165,7 +157,7 @@ class BulkPushRuleEvaluator(object):
|
|||
continue
|
||||
|
||||
if not event.is_state():
|
||||
is_ignored = yield self.store.is_ignored_by(event.sender, uid)
|
||||
is_ignored = await self.store.is_ignored_by(event.sender, uid)
|
||||
if is_ignored:
|
||||
continue
|
||||
|
||||
|
@ -197,7 +189,7 @@ class BulkPushRuleEvaluator(object):
|
|||
# Mark in the DB staging area the push actions for users who should be
|
||||
# notified for this event. (This will then get handled when we persist
|
||||
# the event)
|
||||
yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
|
||||
await self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
|
||||
|
||||
|
||||
def _condition_checker(evaluator, conditions, uid, display_name, cache):
|
||||
|
@ -274,8 +266,7 @@ class RulesForRoom(object):
|
|||
# to self around in the callback.
|
||||
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rules(self, event, context):
|
||||
async def get_rules(self, event, context):
|
||||
"""Given an event context return the rules for all users who are
|
||||
currently in the room.
|
||||
"""
|
||||
|
@ -286,7 +277,7 @@ class RulesForRoom(object):
|
|||
self.room_push_rule_cache_metrics.inc_hits()
|
||||
return self.rules_by_user
|
||||
|
||||
with (yield self.linearizer.queue(())):
|
||||
with (await self.linearizer.queue(())):
|
||||
if state_group and self.state_group == state_group:
|
||||
logger.debug("Using cached rules for %r", self.room_id)
|
||||
self.room_push_rule_cache_metrics.inc_hits()
|
||||
|
@ -304,9 +295,7 @@ class RulesForRoom(object):
|
|||
|
||||
push_rules_delta_state_cache_metric.inc_hits()
|
||||
else:
|
||||
current_state_ids = yield defer.ensureDeferred(
|
||||
context.get_current_state_ids()
|
||||
)
|
||||
current_state_ids = await context.get_current_state_ids()
|
||||
push_rules_delta_state_cache_metric.inc_misses()
|
||||
|
||||
push_rules_state_size_counter.inc(len(current_state_ids))
|
||||
|
@ -353,7 +342,7 @@ class RulesForRoom(object):
|
|||
# If we have some memebr events we haven't seen, look them up
|
||||
# and fetch push rules for them if appropriate.
|
||||
logger.debug("Found new member events %r", missing_member_event_ids)
|
||||
yield self._update_rules_with_member_event_ids(
|
||||
await self._update_rules_with_member_event_ids(
|
||||
ret_rules_by_user, missing_member_event_ids, state_group, event
|
||||
)
|
||||
else:
|
||||
|
@ -371,8 +360,7 @@ class RulesForRoom(object):
|
|||
)
|
||||
return ret_rules_by_user
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_rules_with_member_event_ids(
|
||||
async def _update_rules_with_member_event_ids(
|
||||
self, ret_rules_by_user, member_event_ids, state_group, event
|
||||
):
|
||||
"""Update the partially filled rules_by_user dict by fetching rules for
|
||||
|
@ -388,7 +376,7 @@ class RulesForRoom(object):
|
|||
"""
|
||||
sequence = self.sequence
|
||||
|
||||
rows = yield self.store.get_membership_from_event_ids(member_event_ids.values())
|
||||
rows = await self.store.get_membership_from_event_ids(member_event_ids.values())
|
||||
|
||||
members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
|
||||
|
||||
|
@ -410,7 +398,7 @@ class RulesForRoom(object):
|
|||
|
||||
logger.debug("Joined: %r", interested_in_user_ids)
|
||||
|
||||
if_users_with_pushers = yield self.store.get_if_users_have_pushers(
|
||||
if_users_with_pushers = await self.store.get_if_users_have_pushers(
|
||||
interested_in_user_ids, on_invalidate=self.invalidate_all_cb
|
||||
)
|
||||
|
||||
|
@ -420,7 +408,7 @@ class RulesForRoom(object):
|
|||
|
||||
logger.debug("With pushers: %r", user_ids)
|
||||
|
||||
users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
|
||||
users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
|
||||
self.room_id, on_invalidate=self.invalidate_all_cb
|
||||
)
|
||||
|
||||
|
@ -431,7 +419,7 @@ class RulesForRoom(object):
|
|||
if uid in interested_in_user_ids:
|
||||
user_ids.add(uid)
|
||||
|
||||
rules_by_user = yield self.store.bulk_get_push_rules(
|
||||
rules_by_user = await self.store.bulk_get_push_rules(
|
||||
user_ids, on_invalidate=self.invalidate_all_cb
|
||||
)
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@ import logging
|
|||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
|
@ -128,12 +127,11 @@ class HttpPusher(object):
|
|||
# but currently that's the only type of receipt anyway...
|
||||
run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_badge(self):
|
||||
async def _update_badge(self):
|
||||
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
|
||||
# to be largely redundant. perhaps we can remove it.
|
||||
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
|
||||
yield self._send_badge(badge)
|
||||
badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
|
||||
await self._send_badge(badge)
|
||||
|
||||
def on_timer(self):
|
||||
self._start_processing()
|
||||
|
@ -152,8 +150,7 @@ class HttpPusher(object):
|
|||
|
||||
run_as_background_process("httppush.process", self._process)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _process(self):
|
||||
async def _process(self):
|
||||
# we should never get here if we are already processing
|
||||
assert not self._is_processing
|
||||
|
||||
|
@ -164,7 +161,7 @@ class HttpPusher(object):
|
|||
while True:
|
||||
starting_max_ordering = self.max_stream_ordering
|
||||
try:
|
||||
yield self._unsafe_process()
|
||||
await self._unsafe_process()
|
||||
except Exception:
|
||||
logger.exception("Exception processing notifs")
|
||||
if self.max_stream_ordering == starting_max_ordering:
|
||||
|
@ -172,8 +169,7 @@ class HttpPusher(object):
|
|||
finally:
|
||||
self._is_processing = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _unsafe_process(self):
|
||||
async def _unsafe_process(self):
|
||||
"""
|
||||
Looks for unset notifications and dispatch them, in order
|
||||
Never call this directly: use _process which will only allow this to
|
||||
|
@ -181,7 +177,7 @@ class HttpPusher(object):
|
|||
"""
|
||||
|
||||
fn = self.store.get_unread_push_actions_for_user_in_range_for_http
|
||||
unprocessed = yield fn(
|
||||
unprocessed = await fn(
|
||||
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
||||
)
|
||||
|
||||
|
@ -203,13 +199,13 @@ class HttpPusher(object):
|
|||
"app_display_name": self.app_display_name,
|
||||
},
|
||||
):
|
||||
processed = yield self._process_one(push_action)
|
||||
processed = await self._process_one(push_action)
|
||||
|
||||
if processed:
|
||||
http_push_processed_counter.inc()
|
||||
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
||||
self.last_stream_ordering = push_action["stream_ordering"]
|
||||
pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
|
||||
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_id,
|
||||
|
@ -224,14 +220,14 @@ class HttpPusher(object):
|
|||
|
||||
if self.failing_since:
|
||||
self.failing_since = None
|
||||
yield self.store.update_pusher_failing_since(
|
||||
await self.store.update_pusher_failing_since(
|
||||
self.app_id, self.pushkey, self.user_id, self.failing_since
|
||||
)
|
||||
else:
|
||||
http_push_failed_counter.inc()
|
||||
if not self.failing_since:
|
||||
self.failing_since = self.clock.time_msec()
|
||||
yield self.store.update_pusher_failing_since(
|
||||
await self.store.update_pusher_failing_since(
|
||||
self.app_id, self.pushkey, self.user_id, self.failing_since
|
||||
)
|
||||
|
||||
|
@ -250,7 +246,7 @@ class HttpPusher(object):
|
|||
)
|
||||
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
||||
self.last_stream_ordering = push_action["stream_ordering"]
|
||||
pusher_still_exists = yield self.store.update_pusher_last_stream_ordering(
|
||||
pusher_still_exists = await self.store.update_pusher_last_stream_ordering(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_id,
|
||||
|
@ -263,7 +259,7 @@ class HttpPusher(object):
|
|||
return
|
||||
|
||||
self.failing_since = None
|
||||
yield self.store.update_pusher_failing_since(
|
||||
await self.store.update_pusher_failing_since(
|
||||
self.app_id, self.pushkey, self.user_id, self.failing_since
|
||||
)
|
||||
else:
|
||||
|
@ -276,18 +272,17 @@ class HttpPusher(object):
|
|||
)
|
||||
break
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _process_one(self, push_action):
|
||||
async def _process_one(self, push_action):
|
||||
if "notify" not in push_action["actions"]:
|
||||
return True
|
||||
|
||||
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
|
||||
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
|
||||
badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
|
||||
|
||||
event = yield self.store.get_event(push_action["event_id"], allow_none=True)
|
||||
event = await self.store.get_event(push_action["event_id"], allow_none=True)
|
||||
if event is None:
|
||||
return True # It's been redacted
|
||||
rejected = yield self.dispatch_push(event, tweaks, badge)
|
||||
rejected = await self.dispatch_push(event, tweaks, badge)
|
||||
if rejected is False:
|
||||
return False
|
||||
|
||||
|
@ -301,11 +296,10 @@ class HttpPusher(object):
|
|||
)
|
||||
else:
|
||||
logger.info("Pushkey %s was rejected: removing", pk)
|
||||
yield self.hs.remove_pusher(self.app_id, pk, self.user_id)
|
||||
await self.hs.remove_pusher(self.app_id, pk, self.user_id)
|
||||
return True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _build_notification_dict(self, event, tweaks, badge):
|
||||
async def _build_notification_dict(self, event, tweaks, badge):
|
||||
priority = "low"
|
||||
if (
|
||||
event.type == EventTypes.Encrypted
|
||||
|
@ -335,7 +329,7 @@ class HttpPusher(object):
|
|||
}
|
||||
return d
|
||||
|
||||
ctx = yield push_tools.get_context_for_event(
|
||||
ctx = await push_tools.get_context_for_event(
|
||||
self.storage, self.state_handler, event, self.user_id
|
||||
)
|
||||
|
||||
|
@ -377,13 +371,12 @@ class HttpPusher(object):
|
|||
|
||||
return d
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def dispatch_push(self, event, tweaks, badge):
|
||||
notification_dict = yield self._build_notification_dict(event, tweaks, badge)
|
||||
async def dispatch_push(self, event, tweaks, badge):
|
||||
notification_dict = await self._build_notification_dict(event, tweaks, badge)
|
||||
if not notification_dict:
|
||||
return []
|
||||
try:
|
||||
resp = yield self.http_client.post_json_get_json(
|
||||
resp = await self.http_client.post_json_get_json(
|
||||
self.url, notification_dict
|
||||
)
|
||||
except Exception as e:
|
||||
|
@ -400,8 +393,7 @@ class HttpPusher(object):
|
|||
rejected = resp["rejected"]
|
||||
return rejected
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _send_badge(self, badge):
|
||||
async def _send_badge(self, badge):
|
||||
"""
|
||||
Args:
|
||||
badge (int): number of unread messages
|
||||
|
@ -424,7 +416,7 @@ class HttpPusher(object):
|
|||
}
|
||||
}
|
||||
try:
|
||||
yield self.http_client.post_json_get_json(self.url, d)
|
||||
await self.http_client.post_json_get_json(self.url, d)
|
||||
http_badges_processed_counter.inc()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
|
|
|
@ -16,8 +16,6 @@
|
|||
import logging
|
||||
import re
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -29,8 +27,7 @@ ALIAS_RE = re.compile(r"^#.*:.+$")
|
|||
ALL_ALONE = "Empty Room"
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def calculate_room_name(
|
||||
async def calculate_room_name(
|
||||
store,
|
||||
room_state_ids,
|
||||
user_id,
|
||||
|
@ -53,7 +50,7 @@ def calculate_room_name(
|
|||
"""
|
||||
# does it have a name?
|
||||
if (EventTypes.Name, "") in room_state_ids:
|
||||
m_room_name = yield store.get_event(
|
||||
m_room_name = await store.get_event(
|
||||
room_state_ids[(EventTypes.Name, "")], allow_none=True
|
||||
)
|
||||
if m_room_name and m_room_name.content and m_room_name.content["name"]:
|
||||
|
@ -61,7 +58,7 @@ def calculate_room_name(
|
|||
|
||||
# does it have a canonical alias?
|
||||
if (EventTypes.CanonicalAlias, "") in room_state_ids:
|
||||
canon_alias = yield store.get_event(
|
||||
canon_alias = await store.get_event(
|
||||
room_state_ids[(EventTypes.CanonicalAlias, "")], allow_none=True
|
||||
)
|
||||
if (
|
||||
|
@ -81,7 +78,7 @@ def calculate_room_name(
|
|||
|
||||
my_member_event = None
|
||||
if (EventTypes.Member, user_id) in room_state_ids:
|
||||
my_member_event = yield store.get_event(
|
||||
my_member_event = await store.get_event(
|
||||
room_state_ids[(EventTypes.Member, user_id)], allow_none=True
|
||||
)
|
||||
|
||||
|
@ -90,7 +87,7 @@ def calculate_room_name(
|
|||
and my_member_event.content["membership"] == "invite"
|
||||
):
|
||||
if (EventTypes.Member, my_member_event.sender) in room_state_ids:
|
||||
inviter_member_event = yield store.get_event(
|
||||
inviter_member_event = await store.get_event(
|
||||
room_state_ids[(EventTypes.Member, my_member_event.sender)],
|
||||
allow_none=True,
|
||||
)
|
||||
|
@ -107,7 +104,7 @@ def calculate_room_name(
|
|||
# we're going to have to generate a name based on who's in the room,
|
||||
# so find out who is in the room that isn't the user.
|
||||
if EventTypes.Member in room_state_bytype_ids:
|
||||
member_events = yield store.get_events(
|
||||
member_events = await store.get_events(
|
||||
list(room_state_bytype_ids[EventTypes.Member].values())
|
||||
)
|
||||
all_members = [
|
||||
|
|
|
@ -13,53 +13,40 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
|
||||
from synapse.storage import Storage
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_badge_count(store, user_id):
|
||||
invites = yield store.get_invited_rooms_for_local_user(user_id)
|
||||
joins = yield store.get_rooms_for_user(user_id)
|
||||
|
||||
my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read")
|
||||
async def get_badge_count(store, user_id):
|
||||
invites = await store.get_invited_rooms_for_local_user(user_id)
|
||||
joins = await store.get_rooms_for_user(user_id)
|
||||
|
||||
badge = len(invites)
|
||||
|
||||
for room_id in joins:
|
||||
if room_id in my_receipts_by_room:
|
||||
last_unread_event_id = my_receipts_by_room[room_id]
|
||||
|
||||
notifs = yield (
|
||||
store.get_unread_event_push_actions_by_room_for_user(
|
||||
room_id, user_id, last_unread_event_id
|
||||
)
|
||||
)
|
||||
unread_count = await store.get_unread_message_count_for_user(room_id, user_id)
|
||||
# return one badge count per conversation, as count per
|
||||
# message is so noisy as to be almost useless
|
||||
badge += 1 if notifs["notify_count"] else 0
|
||||
badge += 1 if unread_count else 0
|
||||
return badge
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_context_for_event(storage: Storage, state_handler, ev, user_id):
|
||||
async def get_context_for_event(storage: Storage, state_handler, ev, user_id):
|
||||
ctx = {}
|
||||
|
||||
room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id)
|
||||
room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id)
|
||||
|
||||
# we no longer bother setting room_alias, and make room_name the
|
||||
# human-readable name instead, be that m.room.name, an alias or
|
||||
# a list of people in the room
|
||||
name = yield calculate_room_name(
|
||||
name = await calculate_room_name(
|
||||
storage.main, room_state_ids, user_id, fallback_to_single_member=False
|
||||
)
|
||||
if name:
|
||||
ctx["name"] = name
|
||||
|
||||
sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
|
||||
sender_state_event = yield storage.main.get_event(sender_state_event_id)
|
||||
sender_state_event = await storage.main.get_event(sender_state_event_id)
|
||||
ctx["sender_display_name"] = name_from_member_event(sender_state_event)
|
||||
|
||||
return ctx
|
||||
|
|
|
@ -19,8 +19,6 @@ from typing import TYPE_CHECKING, Dict, Union
|
|||
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.push import PusherConfigException
|
||||
from synapse.push.emailpusher import EmailPusher
|
||||
|
@ -52,7 +50,7 @@ class PusherPool:
|
|||
Note that it is expected that each pusher will have its own 'processing' loop which
|
||||
will send out the notifications in the background, rather than blocking until the
|
||||
notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and
|
||||
Pusher.on_new_receipts are not expected to return deferreds.
|
||||
Pusher.on_new_receipts are not expected to return awaitables.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
|
@ -77,8 +75,7 @@ class PusherPool:
|
|||
return
|
||||
run_as_background_process("start_pushers", self._start_pushers)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_pusher(
|
||||
async def add_pusher(
|
||||
self,
|
||||
user_id,
|
||||
access_token,
|
||||
|
@ -94,7 +91,7 @@ class PusherPool:
|
|||
"""Creates a new pusher and adds it to the pool
|
||||
|
||||
Returns:
|
||||
Deferred[EmailPusher|HttpPusher]
|
||||
EmailPusher|HttpPusher
|
||||
"""
|
||||
|
||||
time_now_msec = self.clock.time_msec()
|
||||
|
@ -124,9 +121,9 @@ class PusherPool:
|
|||
# create the pusher setting last_stream_ordering to the current maximum
|
||||
# stream ordering in event_push_actions, so it will process
|
||||
# pushes from this point onwards.
|
||||
last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering()
|
||||
last_stream_ordering = await self.store.get_latest_push_action_stream_ordering()
|
||||
|
||||
yield self.store.add_pusher(
|
||||
await self.store.add_pusher(
|
||||
user_id=user_id,
|
||||
access_token=access_token,
|
||||
kind=kind,
|
||||
|
@ -140,15 +137,14 @@ class PusherPool:
|
|||
last_stream_ordering=last_stream_ordering,
|
||||
profile_tag=profile_tag,
|
||||
)
|
||||
pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id)
|
||||
pusher = await self.start_pusher_by_id(app_id, pushkey, user_id)
|
||||
|
||||
return pusher
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_pushers_by_app_id_and_pushkey_not_user(
|
||||
async def remove_pushers_by_app_id_and_pushkey_not_user(
|
||||
self, app_id, pushkey, not_user_id
|
||||
):
|
||||
to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
|
||||
to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
|
||||
for p in to_remove:
|
||||
if p["user_name"] != not_user_id:
|
||||
logger.info(
|
||||
|
@ -157,10 +153,9 @@ class PusherPool:
|
|||
pushkey,
|
||||
p["user_name"],
|
||||
)
|
||||
yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
|
||||
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_pushers_by_access_token(self, user_id, access_tokens):
|
||||
async def remove_pushers_by_access_token(self, user_id, access_tokens):
|
||||
"""Remove the pushers for a given user corresponding to a set of
|
||||
access_tokens.
|
||||
|
||||
|
@ -173,7 +168,7 @@ class PusherPool:
|
|||
return
|
||||
|
||||
tokens = set(access_tokens)
|
||||
for p in (yield self.store.get_pushers_by_user_id(user_id)):
|
||||
for p in await self.store.get_pushers_by_user_id(user_id):
|
||||
if p["access_token"] in tokens:
|
||||
logger.info(
|
||||
"Removing pusher for app id %s, pushkey %s, user %s",
|
||||
|
@ -181,16 +176,15 @@ class PusherPool:
|
|||
p["pushkey"],
|
||||
p["user_name"],
|
||||
)
|
||||
yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
|
||||
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_new_notifications(self, min_stream_id, max_stream_id):
|
||||
async def on_new_notifications(self, min_stream_id, max_stream_id):
|
||||
if not self.pushers:
|
||||
# nothing to do here.
|
||||
return
|
||||
|
||||
try:
|
||||
users_affected = yield self.store.get_push_action_users_in_range(
|
||||
users_affected = await self.store.get_push_action_users_in_range(
|
||||
min_stream_id, max_stream_id
|
||||
)
|
||||
|
||||
|
@ -202,8 +196,7 @@ class PusherPool:
|
|||
except Exception:
|
||||
logger.exception("Exception in pusher on_new_notifications")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
|
||||
async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
|
||||
if not self.pushers:
|
||||
# nothing to do here.
|
||||
return
|
||||
|
@ -211,7 +204,7 @@ class PusherPool:
|
|||
try:
|
||||
# Need to subtract 1 from the minimum because the lower bound here
|
||||
# is not inclusive
|
||||
users_affected = yield self.store.get_users_sent_receipts_between(
|
||||
users_affected = await self.store.get_users_sent_receipts_between(
|
||||
min_stream_id - 1, max_stream_id
|
||||
)
|
||||
|
||||
|
@ -223,12 +216,11 @@ class PusherPool:
|
|||
except Exception:
|
||||
logger.exception("Exception in pusher on_new_receipts")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start_pusher_by_id(self, app_id, pushkey, user_id):
|
||||
async def start_pusher_by_id(self, app_id, pushkey, user_id):
|
||||
"""Look up the details for the given pusher, and start it
|
||||
|
||||
Returns:
|
||||
Deferred[EmailPusher|HttpPusher|None]: The pusher started, if any
|
||||
EmailPusher|HttpPusher|None: The pusher started, if any
|
||||
"""
|
||||
if not self._should_start_pushers:
|
||||
return
|
||||
|
@ -236,7 +228,7 @@ class PusherPool:
|
|||
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
|
||||
return
|
||||
|
||||
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
|
||||
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
|
||||
|
||||
pusher_dict = None
|
||||
for r in resultlist:
|
||||
|
@ -245,34 +237,29 @@ class PusherPool:
|
|||
|
||||
pusher = None
|
||||
if pusher_dict:
|
||||
pusher = yield self._start_pusher(pusher_dict)
|
||||
pusher = await self._start_pusher(pusher_dict)
|
||||
|
||||
return pusher
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _start_pushers(self):
|
||||
async def _start_pushers(self) -> None:
|
||||
"""Start all the pushers
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
pushers = yield self.store.get_all_pushers()
|
||||
pushers = await self.store.get_all_pushers()
|
||||
|
||||
# Stagger starting up the pushers so we don't completely drown the
|
||||
# process on start up.
|
||||
yield concurrently_execute(self._start_pusher, pushers, 10)
|
||||
await concurrently_execute(self._start_pusher, pushers, 10)
|
||||
|
||||
logger.info("Started pushers")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _start_pusher(self, pusherdict):
|
||||
async def _start_pusher(self, pusherdict):
|
||||
"""Start the given pusher
|
||||
|
||||
Args:
|
||||
pusherdict (dict): dict with the values pulled from the db table
|
||||
|
||||
Returns:
|
||||
Deferred[EmailPusher|HttpPusher]
|
||||
EmailPusher|HttpPusher
|
||||
"""
|
||||
if not self._pusher_shard_config.should_handle(
|
||||
self._instance_name, pusherdict["user_name"]
|
||||
|
@ -315,7 +302,7 @@ class PusherPool:
|
|||
user_id = pusherdict["user_name"]
|
||||
last_stream_ordering = pusherdict["last_stream_ordering"]
|
||||
if last_stream_ordering:
|
||||
have_notifs = yield self.store.get_if_maybe_push_in_range_for_user(
|
||||
have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
|
||||
user_id, last_stream_ordering
|
||||
)
|
||||
else:
|
||||
|
@ -327,8 +314,7 @@ class PusherPool:
|
|||
|
||||
return p
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_pusher(self, app_id, pushkey, user_id):
|
||||
async def remove_pusher(self, app_id, pushkey, user_id):
|
||||
appid_pushkey = "%s:%s" % (app_id, pushkey)
|
||||
|
||||
byuser = self.pushers.get(user_id, {})
|
||||
|
@ -340,6 +326,6 @@ class PusherPool:
|
|||
|
||||
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
|
||||
|
||||
yield self.store.delete_pusher_by_app_id_pushkey_user_id(
|
||||
await self.store.delete_pusher_by_app_id_pushkey_user_id(
|
||||
app_id, pushkey, user_id
|
||||
)
|
||||
|
|
|
@ -43,7 +43,7 @@ REQUIREMENTS = [
|
|||
"jsonschema>=2.5.1",
|
||||
"frozendict>=1",
|
||||
"unpaddedbase64>=1.1.0",
|
||||
"canonicaljson>=1.1.3",
|
||||
"canonicaljson>=1.2.0",
|
||||
# we use the type definitions added in signedjson 1.1.
|
||||
"signedjson>=1.1.0",
|
||||
"pynacl>=1.2.1",
|
||||
|
|
|
@ -78,7 +78,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||
"""
|
||||
event_payloads = []
|
||||
for event, context in event_and_contexts:
|
||||
serialized_context = yield context.serialize(event, store)
|
||||
serialized_context = yield defer.ensureDeferred(
|
||||
context.serialize(event, store)
|
||||
)
|
||||
|
||||
event_payloads.append(
|
||||
{
|
||||
|
|
|
@ -77,7 +77,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
|||
extra_users (list(UserID)): Any extra users to notify about event
|
||||
"""
|
||||
|
||||
serialized_context = yield context.serialize(event, store)
|
||||
serialized_context = yield defer.ensureDeferred(context.serialize(event, store))
|
||||
|
||||
payload = {
|
||||
"event": event.get_pdu_json(),
|
||||
|
|
|
@ -103,6 +103,14 @@ class DeleteRoomRestServlet(RestServlet):
|
|||
Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
purge = content.get("purge", True)
|
||||
if not isinstance(purge, bool):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Param 'purge' must be a boolean, if given",
|
||||
Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
ret = await self.room_shutdown_handler.shutdown_room(
|
||||
room_id=room_id,
|
||||
new_room_user_id=content.get("new_room_user_id"),
|
||||
|
@ -113,6 +121,7 @@ class DeleteRoomRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
# Purge room
|
||||
if purge:
|
||||
await self.pagination_handler.purge_room(room_id)
|
||||
|
||||
return (200, ret)
|
||||
|
|
|
@ -426,6 +426,7 @@ class SyncRestServlet(RestServlet):
|
|||
result["ephemeral"] = {"events": ephemeral_events}
|
||||
result["unread_notifications"] = room.unread_notifications
|
||||
result["summary"] = room.summary
|
||||
result["org.matrix.msc2654.unread_count"] = room.unread_count
|
||||
|
||||
return result
|
||||
|
||||
|
|
|
@ -17,7 +17,9 @@
|
|||
import logging
|
||||
import os
|
||||
import urllib
|
||||
from typing import Awaitable
|
||||
|
||||
from twisted.internet.interfaces import IConsumer
|
||||
from twisted.protocols.basic import FileSender
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError, cs_error
|
||||
|
@ -240,14 +242,14 @@ class Responder(object):
|
|||
held can be cleaned up.
|
||||
"""
|
||||
|
||||
def write_to_consumer(self, consumer):
|
||||
def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
|
||||
"""Stream response into consumer
|
||||
|
||||
Args:
|
||||
consumer (IConsumer)
|
||||
consumer: The consumer to stream into.
|
||||
|
||||
Returns:
|
||||
Deferred: Resolves once the response has finished being written
|
||||
Resolves once the response has finished being written
|
||||
"""
|
||||
pass
|
||||
|
||||
|
|
|
@ -18,10 +18,11 @@ import errno
|
|||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Dict, Tuple
|
||||
from typing import IO, Dict, Optional, Tuple
|
||||
|
||||
import twisted.internet.error
|
||||
import twisted.web.http
|
||||
from twisted.web.http import Request
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.api.errors import (
|
||||
|
@ -40,6 +41,7 @@ from synapse.util.stringutils import random_string
|
|||
|
||||
from ._base import (
|
||||
FileInfo,
|
||||
Responder,
|
||||
get_filename_from_headers,
|
||||
respond_404,
|
||||
respond_with_responder,
|
||||
|
@ -135,19 +137,24 @@ class MediaRepository(object):
|
|||
self.recently_accessed_locals.add(media_id)
|
||||
|
||||
async def create_content(
|
||||
self, media_type, upload_name, content, content_length, auth_user
|
||||
):
|
||||
self,
|
||||
media_type: str,
|
||||
upload_name: str,
|
||||
content: IO,
|
||||
content_length: int,
|
||||
auth_user: str,
|
||||
) -> str:
|
||||
"""Store uploaded content for a local user and return the mxc URL
|
||||
|
||||
Args:
|
||||
media_type(str): The content type of the file
|
||||
upload_name(str): The name of the file
|
||||
media_type: The content type of the file
|
||||
upload_name: The name of the file
|
||||
content: A file like object that is the content to store
|
||||
content_length(int): The length of the content
|
||||
auth_user(str): The user_id of the uploader
|
||||
content_length: The length of the content
|
||||
auth_user: The user_id of the uploader
|
||||
|
||||
Returns:
|
||||
Deferred[str]: The mxc url of the stored content
|
||||
The mxc url of the stored content
|
||||
"""
|
||||
media_id = random_string(24)
|
||||
|
||||
|
@ -170,19 +177,20 @@ class MediaRepository(object):
|
|||
|
||||
return "mxc://%s/%s" % (self.server_name, media_id)
|
||||
|
||||
async def get_local_media(self, request, media_id, name):
|
||||
async def get_local_media(
|
||||
self, request: Request, media_id: str, name: Optional[str]
|
||||
) -> None:
|
||||
"""Responds to reqests for local media, if exists, or returns 404.
|
||||
|
||||
Args:
|
||||
request(twisted.web.http.Request)
|
||||
media_id (str): The media ID of the content. (This is the same as
|
||||
request: The incoming request.
|
||||
media_id: The media ID of the content. (This is the same as
|
||||
the file_id for local content.)
|
||||
name (str|None): Optional name that, if specified, will be used as
|
||||
name: Optional name that, if specified, will be used as
|
||||
the filename in the Content-Disposition header of the response.
|
||||
|
||||
Returns:
|
||||
Deferred: Resolves once a response has successfully been written
|
||||
to request
|
||||
Resolves once a response has successfully been written to request
|
||||
"""
|
||||
media_info = await self.store.get_local_media(media_id)
|
||||
if not media_info or media_info["quarantined_by"]:
|
||||
|
@ -203,20 +211,20 @@ class MediaRepository(object):
|
|||
request, responder, media_type, media_length, upload_name
|
||||
)
|
||||
|
||||
async def get_remote_media(self, request, server_name, media_id, name):
|
||||
async def get_remote_media(
|
||||
self, request: Request, server_name: str, media_id: str, name: Optional[str]
|
||||
) -> None:
|
||||
"""Respond to requests for remote media.
|
||||
|
||||
Args:
|
||||
request(twisted.web.http.Request)
|
||||
server_name (str): Remote server_name where the media originated.
|
||||
media_id (str): The media ID of the content (as defined by the
|
||||
remote server).
|
||||
name (str|None): Optional name that, if specified, will be used as
|
||||
request: The incoming request.
|
||||
server_name: Remote server_name where the media originated.
|
||||
media_id: The media ID of the content (as defined by the remote server).
|
||||
name: Optional name that, if specified, will be used as
|
||||
the filename in the Content-Disposition header of the response.
|
||||
|
||||
Returns:
|
||||
Deferred: Resolves once a response has successfully been written
|
||||
to request
|
||||
Resolves once a response has successfully been written to request
|
||||
"""
|
||||
if (
|
||||
self.federation_domain_whitelist is not None
|
||||
|
@ -245,17 +253,16 @@ class MediaRepository(object):
|
|||
else:
|
||||
respond_404(request)
|
||||
|
||||
async def get_remote_media_info(self, server_name, media_id):
|
||||
async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
|
||||
"""Gets the media info associated with the remote file, downloading
|
||||
if necessary.
|
||||
|
||||
Args:
|
||||
server_name (str): Remote server_name where the media originated.
|
||||
media_id (str): The media ID of the content (as defined by the
|
||||
remote server).
|
||||
server_name: Remote server_name where the media originated.
|
||||
media_id: The media ID of the content (as defined by the remote server).
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: The media_info of the file
|
||||
The media info of the file
|
||||
"""
|
||||
if (
|
||||
self.federation_domain_whitelist is not None
|
||||
|
@ -278,7 +285,9 @@ class MediaRepository(object):
|
|||
|
||||
return media_info
|
||||
|
||||
async def _get_remote_media_impl(self, server_name, media_id):
|
||||
async def _get_remote_media_impl(
|
||||
self, server_name: str, media_id: str
|
||||
) -> Tuple[Optional[Responder], dict]:
|
||||
"""Looks for media in local cache, if not there then attempt to
|
||||
download from remote server.
|
||||
|
||||
|
@ -288,7 +297,7 @@ class MediaRepository(object):
|
|||
remote server).
|
||||
|
||||
Returns:
|
||||
Deferred[(Responder, media_info)]
|
||||
A tuple of responder and the media info of the file.
|
||||
"""
|
||||
media_info = await self.store.get_cached_remote_media(server_name, media_id)
|
||||
|
||||
|
@ -319,19 +328,21 @@ class MediaRepository(object):
|
|||
responder = await self.media_storage.fetch_media(file_info)
|
||||
return responder, media_info
|
||||
|
||||
async def _download_remote_file(self, server_name, media_id, file_id):
|
||||
async def _download_remote_file(
|
||||
self, server_name: str, media_id: str, file_id: str
|
||||
) -> dict:
|
||||
"""Attempt to download the remote file from the given server name,
|
||||
using the given file_id as the local id.
|
||||
|
||||
Args:
|
||||
server_name (str): Originating server
|
||||
media_id (str): The media ID of the content (as defined by the
|
||||
server_name: Originating server
|
||||
media_id: The media ID of the content (as defined by the
|
||||
remote server). This is different than the file_id, which is
|
||||
locally generated.
|
||||
file_id (str): Local file ID
|
||||
file_id: Local file ID
|
||||
|
||||
Returns:
|
||||
Deferred[MediaInfo]
|
||||
The media info of the file.
|
||||
"""
|
||||
|
||||
file_info = FileInfo(server_name=server_name, file_id=file_id)
|
||||
|
@ -549,25 +560,31 @@ class MediaRepository(object):
|
|||
return output_path
|
||||
|
||||
async def _generate_thumbnails(
|
||||
self, server_name, media_id, file_id, media_type, url_cache=False
|
||||
):
|
||||
self,
|
||||
server_name: Optional[str],
|
||||
media_id: str,
|
||||
file_id: str,
|
||||
media_type: str,
|
||||
url_cache: bool = False,
|
||||
) -> Optional[dict]:
|
||||
"""Generate and store thumbnails for an image.
|
||||
|
||||
Args:
|
||||
server_name (str|None): The server name if remote media, else None if local
|
||||
media_id (str): The media ID of the content. (This is the same as
|
||||
server_name: The server name if remote media, else None if local
|
||||
media_id: The media ID of the content. (This is the same as
|
||||
the file_id for local content)
|
||||
file_id (str): Local file ID
|
||||
media_type (str): The content type of the file
|
||||
url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
|
||||
file_id: Local file ID
|
||||
media_type: The content type of the file
|
||||
url_cache: If we are thumbnailing images downloaded for the URL cache,
|
||||
used exclusively by the url previewer
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: Dict with "width" and "height" keys of original image
|
||||
Dict with "width" and "height" keys of original image or None if the
|
||||
media cannot be thumbnailed.
|
||||
"""
|
||||
requirements = self._get_thumbnail_requirements(media_type)
|
||||
if not requirements:
|
||||
return
|
||||
return None
|
||||
|
||||
input_path = await self.media_storage.ensure_media_is_in_local_cache(
|
||||
FileInfo(server_name, file_id, url_cache=url_cache)
|
||||
|
@ -584,7 +601,7 @@ class MediaRepository(object):
|
|||
m_height,
|
||||
self.max_image_pixels,
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
if thumbnailer.transpose_method is not None:
|
||||
m_width, m_height = await defer_to_thread(
|
||||
|
|
|
@ -12,13 +12,12 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Optional
|
||||
from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
|
||||
|
||||
from twisted.protocols.basic import FileSender
|
||||
|
||||
|
@ -26,6 +25,12 @@ from synapse.logging.context import defer_to_thread, make_deferred_yieldable
|
|||
from synapse.util.file_consumer import BackgroundFileConsumer
|
||||
|
||||
from ._base import FileInfo, Responder
|
||||
from .filepath import MediaFilePaths
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
from .storage_provider import StorageProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -34,20 +39,25 @@ class MediaStorage(object):
|
|||
"""Responsible for storing/fetching files from local sources.
|
||||
|
||||
Args:
|
||||
hs (synapse.server.Homeserver)
|
||||
local_media_directory (str): Base path where we store media on disk
|
||||
filepaths (MediaFilePaths)
|
||||
storage_providers ([StorageProvider]): List of StorageProvider that are
|
||||
used to fetch and store files.
|
||||
hs
|
||||
local_media_directory: Base path where we store media on disk
|
||||
filepaths
|
||||
storage_providers: List of StorageProvider that are used to fetch and store files.
|
||||
"""
|
||||
|
||||
def __init__(self, hs, local_media_directory, filepaths, storage_providers):
|
||||
def __init__(
|
||||
self,
|
||||
hs: "HomeServer",
|
||||
local_media_directory: str,
|
||||
filepaths: MediaFilePaths,
|
||||
storage_providers: Sequence["StorageProvider"],
|
||||
):
|
||||
self.hs = hs
|
||||
self.local_media_directory = local_media_directory
|
||||
self.filepaths = filepaths
|
||||
self.storage_providers = storage_providers
|
||||
|
||||
async def store_file(self, source, file_info: FileInfo) -> str:
|
||||
async def store_file(self, source: IO, file_info: FileInfo) -> str:
|
||||
"""Write `source` to the on disk media store, and also any other
|
||||
configured storage providers
|
||||
|
||||
|
@ -69,7 +79,7 @@ class MediaStorage(object):
|
|||
return fname
|
||||
|
||||
@contextlib.contextmanager
|
||||
def store_into_file(self, file_info):
|
||||
def store_into_file(self, file_info: FileInfo):
|
||||
"""Context manager used to get a file like object to write into, as
|
||||
described by file_info.
|
||||
|
||||
|
@ -85,7 +95,7 @@ class MediaStorage(object):
|
|||
error.
|
||||
|
||||
Args:
|
||||
file_info (FileInfo): Info about the file to store
|
||||
file_info: Info about the file to store
|
||||
|
||||
Example:
|
||||
|
||||
|
@ -143,9 +153,9 @@ class MediaStorage(object):
|
|||
return FileResponder(open(local_path, "rb"))
|
||||
|
||||
for provider in self.storage_providers:
|
||||
res = provider.fetch(path, file_info)
|
||||
# Fetch is supposed to return an Awaitable, but guard against
|
||||
# improper implementations.
|
||||
res = provider.fetch(path, file_info) # type: Any
|
||||
# Fetch is supposed to return an Awaitable[Responder], but guard
|
||||
# against improper implementations.
|
||||
if inspect.isawaitable(res):
|
||||
res = await res
|
||||
if res:
|
||||
|
@ -174,9 +184,9 @@ class MediaStorage(object):
|
|||
os.makedirs(dirname)
|
||||
|
||||
for provider in self.storage_providers:
|
||||
res = provider.fetch(path, file_info)
|
||||
# Fetch is supposed to return an Awaitable, but guard against
|
||||
# improper implementations.
|
||||
res = provider.fetch(path, file_info) # type: Any
|
||||
# Fetch is supposed to return an Awaitable[Responder], but guard
|
||||
# against improper implementations.
|
||||
if inspect.isawaitable(res):
|
||||
res = await res
|
||||
if res:
|
||||
|
@ -190,17 +200,11 @@ class MediaStorage(object):
|
|||
|
||||
raise Exception("file could not be found")
|
||||
|
||||
def _file_info_to_path(self, file_info):
|
||||
def _file_info_to_path(self, file_info: FileInfo) -> str:
|
||||
"""Converts file_info into a relative path.
|
||||
|
||||
The path is suitable for storing files under a directory, e.g. used to
|
||||
store files on local FS under the base media repository directory.
|
||||
|
||||
Args:
|
||||
file_info (FileInfo)
|
||||
|
||||
Returns:
|
||||
str
|
||||
"""
|
||||
if file_info.url_cache:
|
||||
if file_info.thumbnail:
|
||||
|
|
|
@ -231,16 +231,16 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
|
||||
respond_with_json_bytes(request, 200, og, send_cors=True)
|
||||
|
||||
async def _do_preview(self, url, user, ts):
|
||||
async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
|
||||
"""Check the db, and download the URL and build a preview
|
||||
|
||||
Args:
|
||||
url (str):
|
||||
user (str):
|
||||
ts (int):
|
||||
url: The URL to preview.
|
||||
user: The user requesting the preview.
|
||||
ts: The timestamp requested for the preview.
|
||||
|
||||
Returns:
|
||||
Deferred[bytes]: json-encoded og data
|
||||
json-encoded og data
|
||||
"""
|
||||
# check the URL cache in the DB (which will also provide us with
|
||||
# historical previews, if we have any)
|
||||
|
|
|
@ -16,62 +16,62 @@
|
|||
import logging
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import Optional
|
||||
|
||||
from synapse.config._base import Config
|
||||
from synapse.logging.context import defer_to_thread, run_in_background
|
||||
|
||||
from ._base import FileInfo, Responder
|
||||
from .media_storage import FileResponder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StorageProvider(object):
|
||||
class StorageProvider:
|
||||
"""A storage provider is a service that can store uploaded media and
|
||||
retrieve them.
|
||||
"""
|
||||
|
||||
def store_file(self, path, file_info):
|
||||
async def store_file(self, path: str, file_info: FileInfo):
|
||||
"""Store the file described by file_info. The actual contents can be
|
||||
retrieved by reading the file in file_info.upload_path.
|
||||
|
||||
Args:
|
||||
path (str): Relative path of file in local cache
|
||||
file_info (FileInfo)
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
path: Relative path of file in local cache
|
||||
file_info: The metadata of the file.
|
||||
"""
|
||||
pass
|
||||
|
||||
def fetch(self, path, file_info):
|
||||
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
|
||||
"""Attempt to fetch the file described by file_info and stream it
|
||||
into writer.
|
||||
|
||||
Args:
|
||||
path (str): Relative path of file in local cache
|
||||
file_info (FileInfo)
|
||||
path: Relative path of file in local cache
|
||||
file_info: The metadata of the file.
|
||||
|
||||
Returns:
|
||||
Deferred(Responder): Returns a Responder if the provider has the file,
|
||||
otherwise returns None.
|
||||
Returns a Responder if the provider has the file, otherwise returns None.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StorageProviderWrapper(StorageProvider):
|
||||
"""Wraps a storage provider and provides various config options
|
||||
|
||||
Args:
|
||||
backend (StorageProvider)
|
||||
store_local (bool): Whether to store new local files or not.
|
||||
store_synchronous (bool): Whether to wait for file to be successfully
|
||||
backend: The storage provider to wrap.
|
||||
store_local: Whether to store new local files or not.
|
||||
store_synchronous: Whether to wait for file to be successfully
|
||||
uploaded, or todo the upload in the background.
|
||||
store_remote (bool): Whether remote media should be uploaded
|
||||
store_remote: Whether remote media should be uploaded
|
||||
"""
|
||||
|
||||
def __init__(self, backend, store_local, store_synchronous, store_remote):
|
||||
def __init__(
|
||||
self,
|
||||
backend: StorageProvider,
|
||||
store_local: bool,
|
||||
store_synchronous: bool,
|
||||
store_remote: bool,
|
||||
):
|
||||
self.backend = backend
|
||||
self.store_local = store_local
|
||||
self.store_synchronous = store_synchronous
|
||||
|
@ -80,15 +80,15 @@ class StorageProviderWrapper(StorageProvider):
|
|||
def __str__(self):
|
||||
return "StorageProviderWrapper[%s]" % (self.backend,)
|
||||
|
||||
def store_file(self, path, file_info):
|
||||
async def store_file(self, path, file_info):
|
||||
if not file_info.server_name and not self.store_local:
|
||||
return defer.succeed(None)
|
||||
return None
|
||||
|
||||
if file_info.server_name and not self.store_remote:
|
||||
return defer.succeed(None)
|
||||
return None
|
||||
|
||||
if self.store_synchronous:
|
||||
return self.backend.store_file(path, file_info)
|
||||
return await self.backend.store_file(path, file_info)
|
||||
else:
|
||||
# TODO: Handle errors.
|
||||
def store():
|
||||
|
@ -98,10 +98,10 @@ class StorageProviderWrapper(StorageProvider):
|
|||
logger.exception("Error storing file")
|
||||
|
||||
run_in_background(store)
|
||||
return defer.succeed(None)
|
||||
return None
|
||||
|
||||
def fetch(self, path, file_info):
|
||||
return self.backend.fetch(path, file_info)
|
||||
async def fetch(self, path, file_info):
|
||||
return await self.backend.fetch(path, file_info)
|
||||
|
||||
|
||||
class FileStorageProviderBackend(StorageProvider):
|
||||
|
@ -120,7 +120,7 @@ class FileStorageProviderBackend(StorageProvider):
|
|||
def __str__(self):
|
||||
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
|
||||
|
||||
def store_file(self, path, file_info):
|
||||
async def store_file(self, path, file_info):
|
||||
"""See StorageProvider.store_file"""
|
||||
|
||||
primary_fname = os.path.join(self.cache_directory, path)
|
||||
|
@ -130,11 +130,11 @@ class FileStorageProviderBackend(StorageProvider):
|
|||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
return defer_to_thread(
|
||||
return await defer_to_thread(
|
||||
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
|
||||
)
|
||||
|
||||
def fetch(self, path, file_info):
|
||||
async def fetch(self, path, file_info):
|
||||
"""See StorageProvider.fetch"""
|
||||
|
||||
backup_fname = os.path.join(self.base_directory, path)
|
||||
|
|
|
@ -172,6 +172,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
|||
|
||||
self.get_latest_event_ids_in_room.invalidate((room_id,))
|
||||
|
||||
self.get_unread_message_count_for_user.invalidate_many((room_id,))
|
||||
self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
|
||||
|
||||
if not backfilled:
|
||||
|
|
|
@ -15,11 +15,10 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import Database
|
||||
|
@ -166,8 +165,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
|
||||
return {"notify_count": notify_count, "highlight_count": highlight_count}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
|
||||
async def get_push_action_users_in_range(
|
||||
self, min_stream_ordering, max_stream_ordering
|
||||
):
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
|
||||
|
@ -176,26 +176,28 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
|
||||
return [r[0] for r in txn]
|
||||
|
||||
ret = yield self.db.runInteraction("get_push_action_users_in_range", f)
|
||||
ret = await self.db.runInteraction("get_push_action_users_in_range", f)
|
||||
return ret
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_unread_push_actions_for_user_in_range_for_http(
|
||||
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
|
||||
):
|
||||
async def get_unread_push_actions_for_user_in_range_for_http(
|
||||
self,
|
||||
user_id: str,
|
||||
min_stream_ordering: int,
|
||||
max_stream_ordering: int,
|
||||
limit: int = 20,
|
||||
) -> List[dict]:
|
||||
"""Get a list of the most recent unread push actions for a given user,
|
||||
within the given stream ordering range. Called by the httppusher.
|
||||
|
||||
Args:
|
||||
user_id (str): The user to fetch push actions for.
|
||||
min_stream_ordering(int): The exclusive lower bound on the
|
||||
user_id: The user to fetch push actions for.
|
||||
min_stream_ordering: The exclusive lower bound on the
|
||||
stream ordering of event push actions to fetch.
|
||||
max_stream_ordering(int): The inclusive upper bound on the
|
||||
max_stream_ordering: The inclusive upper bound on the
|
||||
stream ordering of event push actions to fetch.
|
||||
limit (int): The maximum number of rows to return.
|
||||
limit: The maximum number of rows to return.
|
||||
Returns:
|
||||
A promise which resolves to a list of dicts with the keys "event_id",
|
||||
"room_id", "stream_ordering", "actions".
|
||||
A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions".
|
||||
The list will be ordered by ascending stream_ordering.
|
||||
The list will have between 0~limit entries.
|
||||
"""
|
||||
|
@ -228,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
|
||||
after_read_receipt = yield self.db.runInteraction(
|
||||
after_read_receipt = await self.db.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
|
||||
)
|
||||
|
||||
|
@ -256,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
|
||||
no_read_receipt = yield self.db.runInteraction(
|
||||
no_read_receipt = await self.db.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
|
||||
)
|
||||
|
||||
|
@ -280,23 +282,25 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
# one of the subqueries may have hit the limit.
|
||||
return notifs[:limit]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_unread_push_actions_for_user_in_range_for_email(
|
||||
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
|
||||
):
|
||||
async def get_unread_push_actions_for_user_in_range_for_email(
|
||||
self,
|
||||
user_id: str,
|
||||
min_stream_ordering: int,
|
||||
max_stream_ordering: int,
|
||||
limit: int = 20,
|
||||
) -> List[dict]:
|
||||
"""Get a list of the most recent unread push actions for a given user,
|
||||
within the given stream ordering range. Called by the emailpusher
|
||||
|
||||
Args:
|
||||
user_id (str): The user to fetch push actions for.
|
||||
min_stream_ordering(int): The exclusive lower bound on the
|
||||
user_id: The user to fetch push actions for.
|
||||
min_stream_ordering: The exclusive lower bound on the
|
||||
stream ordering of event push actions to fetch.
|
||||
max_stream_ordering(int): The inclusive upper bound on the
|
||||
max_stream_ordering: The inclusive upper bound on the
|
||||
stream ordering of event push actions to fetch.
|
||||
limit (int): The maximum number of rows to return.
|
||||
limit: The maximum number of rows to return.
|
||||
Returns:
|
||||
A promise which resolves to a list of dicts with the keys "event_id",
|
||||
"room_id", "stream_ordering", "actions", "received_ts".
|
||||
A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts".
|
||||
The list will be ordered by descending received_ts.
|
||||
The list will have between 0~limit entries.
|
||||
"""
|
||||
|
@ -328,7 +332,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
|
||||
after_read_receipt = yield self.db.runInteraction(
|
||||
after_read_receipt = await self.db.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
|
||||
)
|
||||
|
||||
|
@ -356,7 +360,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
|
||||
no_read_receipt = yield self.db.runInteraction(
|
||||
no_read_receipt = await self.db.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
|
||||
)
|
||||
|
||||
|
@ -411,7 +415,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
_get_if_maybe_push_in_range_for_user_txn,
|
||||
)
|
||||
|
||||
def add_push_actions_to_staging(self, event_id, user_id_actions):
|
||||
async def add_push_actions_to_staging(self, event_id, user_id_actions):
|
||||
"""Add the push actions for the event to the push action staging area.
|
||||
|
||||
Args:
|
||||
|
@ -457,21 +461,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
),
|
||||
)
|
||||
|
||||
return self.db.runInteraction(
|
||||
return await self.db.runInteraction(
|
||||
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_push_actions_from_staging(self, event_id):
|
||||
async def remove_push_actions_from_staging(self, event_id: str) -> None:
|
||||
"""Called if we failed to persist the event to ensure that stale push
|
||||
actions don't build up in the DB
|
||||
|
||||
Args:
|
||||
event_id (str)
|
||||
"""
|
||||
|
||||
try:
|
||||
res = yield self.db.simple_delete(
|
||||
res = await self.db.simple_delete(
|
||||
table="event_push_actions_staging",
|
||||
keyvalues={"event_id": event_id},
|
||||
desc="remove_push_actions_from_staging",
|
||||
|
@ -606,8 +606,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
|
||||
return range_end
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_time_of_last_push_action_before(self, stream_ordering):
|
||||
async def get_time_of_last_push_action_before(self, stream_ordering):
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT e.received_ts"
|
||||
|
@ -620,7 +619,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, (stream_ordering,))
|
||||
return txn.fetchone()
|
||||
|
||||
result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
|
||||
result = await self.db.runInteraction("get_time_of_last_push_action_before", f)
|
||||
return result[0] if result else None
|
||||
|
||||
|
||||
|
@ -650,8 +649,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
|||
self._start_rotate_notifs, 30 * 60 * 1000
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_push_actions_for_user(
|
||||
async def get_push_actions_for_user(
|
||||
self, user_id, before=None, limit=50, only_highlight=False
|
||||
):
|
||||
def f(txn):
|
||||
|
@ -682,18 +680,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
|||
txn.execute(sql, args)
|
||||
return self.db.cursor_to_dict(txn)
|
||||
|
||||
push_actions = yield self.db.runInteraction("get_push_actions_for_user", f)
|
||||
push_actions = await self.db.runInteraction("get_push_actions_for_user", f)
|
||||
for pa in push_actions:
|
||||
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
|
||||
return push_actions
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_latest_push_action_stream_ordering(self):
|
||||
async def get_latest_push_action_stream_ordering(self):
|
||||
def f(txn):
|
||||
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
|
||||
return txn.fetchone()
|
||||
|
||||
result = yield self.db.runInteraction(
|
||||
result = await self.db.runInteraction(
|
||||
"get_latest_push_action_stream_ordering", f
|
||||
)
|
||||
return result[0] or 0
|
||||
|
@ -747,8 +744,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
|||
def _start_rotate_notifs(self):
|
||||
return run_as_background_process("rotate_notifs", self._rotate_notifs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _rotate_notifs(self):
|
||||
async def _rotate_notifs(self):
|
||||
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
|
||||
return
|
||||
self._doing_notif_rotation = True
|
||||
|
@ -757,12 +753,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
|||
while True:
|
||||
logger.info("Rotating notifications")
|
||||
|
||||
caught_up = yield self.db.runInteraction(
|
||||
caught_up = await self.db.runInteraction(
|
||||
"_rotate_notifs", self._rotate_notifs_txn
|
||||
)
|
||||
if caught_up:
|
||||
break
|
||||
yield self.hs.get_clock().sleep(self._rotate_delay)
|
||||
await self.hs.get_clock().sleep(self._rotate_delay)
|
||||
finally:
|
||||
self._doing_notif_rotation = False
|
||||
|
||||
|
|
|
@ -53,6 +53,47 @@ event_counter = Counter(
|
|||
["type", "origin_type", "origin_entity"],
|
||||
)
|
||||
|
||||
STATE_EVENT_TYPES_TO_MARK_UNREAD = {
|
||||
EventTypes.Topic,
|
||||
EventTypes.Name,
|
||||
EventTypes.RoomAvatar,
|
||||
EventTypes.Tombstone,
|
||||
}
|
||||
|
||||
|
||||
def should_count_as_unread(event: EventBase, context: EventContext) -> bool:
|
||||
# Exclude rejected and soft-failed events.
|
||||
if context.rejected or event.internal_metadata.is_soft_failed():
|
||||
return False
|
||||
|
||||
# Exclude notices.
|
||||
if (
|
||||
not event.is_state()
|
||||
and event.type == EventTypes.Message
|
||||
and event.content.get("msgtype") == "m.notice"
|
||||
):
|
||||
return False
|
||||
|
||||
# Exclude edits.
|
||||
relates_to = event.content.get("m.relates_to", {})
|
||||
if relates_to.get("rel_type") == RelationTypes.REPLACE:
|
||||
return False
|
||||
|
||||
# Mark events that have a non-empty string body as unread.
|
||||
body = event.content.get("body")
|
||||
if isinstance(body, str) and body:
|
||||
return True
|
||||
|
||||
# Mark some state events as unread.
|
||||
if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
|
||||
return True
|
||||
|
||||
# Mark encrypted events as unread.
|
||||
if not event.is_state() and event.type == EventTypes.Encrypted:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def encode_json(json_object):
|
||||
"""
|
||||
|
@ -196,6 +237,10 @@ class PersistEventsStore:
|
|||
|
||||
event_counter.labels(event.type, origin_type, origin_entity).inc()
|
||||
|
||||
self.store.get_unread_message_count_for_user.invalidate_many(
|
||||
(event.room_id,),
|
||||
)
|
||||
|
||||
for room_id, new_state in current_state_for_room.items():
|
||||
self.store.get_current_state_ids.prefill((room_id,), new_state)
|
||||
|
||||
|
@ -817,8 +862,9 @@ class PersistEventsStore:
|
|||
"contains_url": (
|
||||
"url" in event.content and isinstance(event.content["url"], str)
|
||||
),
|
||||
"count_as_unread": should_count_as_unread(event, context),
|
||||
}
|
||||
for event, _ in events_and_contexts
|
||||
for event, context in events_and_contexts
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -41,9 +41,15 @@ from synapse.replication.tcp.streams import BackfillStream
|
|||
from synapse.replication.tcp.streams.events import EventsStream
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
|
||||
from synapse.util.caches.descriptors import (
|
||||
Cache,
|
||||
_CacheContext,
|
||||
cached,
|
||||
cachedInlineCallbacks,
|
||||
)
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
|
@ -1358,6 +1364,84 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
|
||||
)
|
||||
|
||||
@cached(tree=True, cache_context=True)
|
||||
async def get_unread_message_count_for_user(
|
||||
self, room_id: str, user_id: str, cache_context: _CacheContext,
|
||||
) -> int:
|
||||
"""Retrieve the count of unread messages for the given room and user.
|
||||
|
||||
Args:
|
||||
room_id: The ID of the room to count unread messages in.
|
||||
user_id: The ID of the user to count unread messages for.
|
||||
|
||||
Returns:
|
||||
The number of unread messages for the given user in the given room.
|
||||
"""
|
||||
with Measure(self._clock, "get_unread_message_count_for_user"):
|
||||
last_read_event_id = await self.get_last_receipt_event_id_for_user(
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
receipt_type="m.read",
|
||||
on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
|
||||
return await self.db.runInteraction(
|
||||
"get_unread_message_count_for_user",
|
||||
self._get_unread_message_count_for_user_txn,
|
||||
user_id,
|
||||
room_id,
|
||||
last_read_event_id,
|
||||
)
|
||||
|
||||
def _get_unread_message_count_for_user_txn(
|
||||
self,
|
||||
txn: Cursor,
|
||||
user_id: str,
|
||||
room_id: str,
|
||||
last_read_event_id: Optional[str],
|
||||
) -> int:
|
||||
if last_read_event_id:
|
||||
# Get the stream ordering for the last read event.
|
||||
stream_ordering = self.db.simple_select_one_onecol_txn(
|
||||
txn=txn,
|
||||
table="events",
|
||||
keyvalues={"room_id": room_id, "event_id": last_read_event_id},
|
||||
retcol="stream_ordering",
|
||||
)
|
||||
else:
|
||||
# If there's no read receipt for that room, it probably means the user hasn't
|
||||
# opened it yet, in which case use the stream ID of their join event.
|
||||
# We can't just set it to 0 otherwise messages from other local users from
|
||||
# before this user joined will be counted as well.
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT stream_ordering FROM local_current_membership
|
||||
LEFT JOIN events USING (event_id, room_id)
|
||||
WHERE membership = 'join'
|
||||
AND user_id = ?
|
||||
AND room_id = ?
|
||||
""",
|
||||
(user_id, room_id),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
|
||||
if row is None:
|
||||
return 0
|
||||
|
||||
stream_ordering = row[0]
|
||||
|
||||
# Count the messages that qualify as unread after the stream ordering we've just
|
||||
# retrieved.
|
||||
sql = """
|
||||
SELECT COUNT(*) FROM events
|
||||
WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread
|
||||
"""
|
||||
|
||||
txn.execute(sql, (user_id, room_id, stream_ordering))
|
||||
row = txn.fetchone()
|
||||
|
||||
return row[0] if row else 0
|
||||
|
||||
|
||||
AllNewEventsResult = namedtuple(
|
||||
"AllNewEventsResult",
|
||||
|
|
|
@ -62,6 +62,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
|||
# event_json
|
||||
# event_push_actions
|
||||
# event_reference_hashes
|
||||
# event_relations
|
||||
# event_search
|
||||
# event_to_state_groups
|
||||
# events
|
||||
|
@ -209,6 +210,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
|||
"event_edges",
|
||||
"event_forward_extremities",
|
||||
"event_reference_hashes",
|
||||
"event_relations",
|
||||
"event_search",
|
||||
"rejections",
|
||||
):
|
||||
|
|
|
@ -23,8 +23,6 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.api.room_versions import RoomVersion, RoomVersions
|
||||
|
@ -32,7 +30,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
|
|||
from synapse.storage.data_stores.main.search import SearchStore
|
||||
from synapse.storage.database import Database, LoggingTransaction
|
||||
from synapse.types import ThirdPartyInstanceID
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -192,8 +190,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
|
||||
return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_largest_public_rooms(
|
||||
async def get_largest_public_rooms(
|
||||
self,
|
||||
network_tuple: Optional[ThirdPartyInstanceID],
|
||||
search_filter: Optional[dict],
|
||||
|
@ -330,10 +327,10 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
|
||||
return results
|
||||
|
||||
ret_val = yield self.db.runInteraction(
|
||||
ret_val = await self.db.runInteraction(
|
||||
"get_largest_public_rooms", _get_largest_public_rooms_txn
|
||||
)
|
||||
defer.returnValue(ret_val)
|
||||
return ret_val
|
||||
|
||||
@cached(max_entries=10000)
|
||||
def is_room_blocked(self, room_id):
|
||||
|
@ -509,8 +506,8 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
"get_rooms_paginate", _get_rooms_paginate_txn,
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(max_entries=10000)
|
||||
def get_ratelimit_for_user(self, user_id):
|
||||
@cached(max_entries=10000)
|
||||
async def get_ratelimit_for_user(self, user_id):
|
||||
"""Check if there are any overrides for ratelimiting for the given
|
||||
user
|
||||
|
||||
|
@ -522,7 +519,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
of RatelimitOverride are None or 0 then ratelimitng has been
|
||||
disabled for that user entirely.
|
||||
"""
|
||||
row = yield self.db.simple_select_one(
|
||||
row = await self.db.simple_select_one(
|
||||
table="ratelimit_override",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("messages_per_second", "burst_count"),
|
||||
|
@ -538,8 +535,8 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
else:
|
||||
return None
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def get_retention_policy_for_room(self, room_id):
|
||||
@cached()
|
||||
async def get_retention_policy_for_room(self, room_id):
|
||||
"""Get the retention policy for a given room.
|
||||
|
||||
If no retention policy has been found for this room, returns a policy defined
|
||||
|
@ -566,19 +563,17 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
|
||||
return self.db.cursor_to_dict(txn)
|
||||
|
||||
ret = yield self.db.runInteraction(
|
||||
ret = await self.db.runInteraction(
|
||||
"get_retention_policy_for_room", get_retention_policy_for_room_txn,
|
||||
)
|
||||
|
||||
# If we don't know this room ID, ret will be None, in this case return the default
|
||||
# policy.
|
||||
if not ret:
|
||||
defer.returnValue(
|
||||
{
|
||||
return {
|
||||
"min_lifetime": self.config.retention_default_min_lifetime,
|
||||
"max_lifetime": self.config.retention_default_max_lifetime,
|
||||
}
|
||||
)
|
||||
|
||||
row = ret[0]
|
||||
|
||||
|
@ -592,7 +587,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
if row["max_lifetime"] is None:
|
||||
row["max_lifetime"] = self.config.retention_default_max_lifetime
|
||||
|
||||
defer.returnValue(row)
|
||||
return row
|
||||
|
||||
def get_media_mxcs_in_room(self, room_id):
|
||||
"""Retrieves all the local and remote media MXC URIs in a given room
|
||||
|
@ -881,8 +876,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||
self._background_add_rooms_room_version_column,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_insert_retention(self, progress, batch_size):
|
||||
async def _background_insert_retention(self, progress, batch_size):
|
||||
"""Retrieves a list of all rooms within a range and inserts an entry for each of
|
||||
them into the room_retention table.
|
||||
NULLs the property's columns if missing from the retention event in the room's
|
||||
|
@ -940,14 +934,14 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||
else:
|
||||
return False
|
||||
|
||||
end = yield self.db.runInteraction(
|
||||
end = await self.db.runInteraction(
|
||||
"insert_room_retention", _background_insert_retention_txn,
|
||||
)
|
||||
|
||||
if end:
|
||||
yield self.db.updates._end_background_update("insert_room_retention")
|
||||
await self.db.updates._end_background_update("insert_room_retention")
|
||||
|
||||
defer.returnValue(batch_size)
|
||||
return batch_size
|
||||
|
||||
async def _background_add_rooms_room_version_column(
|
||||
self, progress: dict, batch_size: int
|
||||
|
@ -1096,8 +1090,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
lock=False,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def store_room(
|
||||
async def store_room(
|
||||
self,
|
||||
room_id: str,
|
||||
room_creator_user_id: str,
|
||||
|
@ -1140,7 +1133,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
yield self.db.runInteraction("store_room_txn", store_room_txn, next_id)
|
||||
await self.db.runInteraction("store_room_txn", store_room_txn, next_id)
|
||||
except Exception as e:
|
||||
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
||||
raise StoreError(500, "Problem creating room.")
|
||||
|
@ -1165,8 +1158,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
lock=False,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_room_is_public(self, room_id, is_public):
|
||||
async def set_room_is_public(self, room_id, is_public):
|
||||
def set_room_is_public_txn(txn, next_id):
|
||||
self.db.simple_update_one_txn(
|
||||
txn,
|
||||
|
@ -1206,13 +1198,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
yield self.db.runInteraction(
|
||||
await self.db.runInteraction(
|
||||
"set_room_is_public", set_room_is_public_txn, next_id
|
||||
)
|
||||
self.hs.get_notifier().on_new_replication_data()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_room_is_public_appservice(
|
||||
async def set_room_is_public_appservice(
|
||||
self, room_id, appservice_id, network_id, is_public
|
||||
):
|
||||
"""Edit the appservice/network specific public room list.
|
||||
|
@ -1287,7 +1278,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
yield self.db.runInteraction(
|
||||
await self.db.runInteraction(
|
||||
"set_room_is_public_appservice",
|
||||
set_room_is_public_appservice_txn,
|
||||
next_id,
|
||||
|
@ -1327,49 +1318,44 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
def get_current_public_room_stream_id(self):
|
||||
return self._public_room_id_gen.get_current_token()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def block_room(self, room_id, user_id):
|
||||
async def block_room(self, room_id: str, user_id: str) -> None:
|
||||
"""Marks the room as blocked. Can be called multiple times.
|
||||
|
||||
Args:
|
||||
room_id (str): Room to block
|
||||
user_id (str): Who blocked it
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
room_id: Room to block
|
||||
user_id: Who blocked it
|
||||
"""
|
||||
yield self.db.simple_upsert(
|
||||
await self.db.simple_upsert(
|
||||
table="blocked_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
values={},
|
||||
insertion_values={"user_id": user_id},
|
||||
desc="block_room",
|
||||
)
|
||||
yield self.db.runInteraction(
|
||||
await self.db.runInteraction(
|
||||
"block_room_invalidation",
|
||||
self._invalidate_cache_and_stream,
|
||||
self.is_room_blocked,
|
||||
(room_id,),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rooms_for_retention_period_in_range(
|
||||
self, min_ms, max_ms, include_null=False
|
||||
):
|
||||
async def get_rooms_for_retention_period_in_range(
|
||||
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
|
||||
) -> Dict[str, dict]:
|
||||
"""Retrieves all of the rooms within the given retention range.
|
||||
|
||||
Optionally includes the rooms which don't have a retention policy.
|
||||
|
||||
Args:
|
||||
min_ms (int|None): Duration in milliseconds that define the lower limit of
|
||||
min_ms: Duration in milliseconds that define the lower limit of
|
||||
the range to handle (exclusive). If None, doesn't set a lower limit.
|
||||
max_ms (int|None): Duration in milliseconds that define the upper limit of
|
||||
max_ms: Duration in milliseconds that define the upper limit of
|
||||
the range to handle (inclusive). If None, doesn't set an upper limit.
|
||||
include_null (bool): Whether to include rooms which retention policy is NULL
|
||||
include_null: Whether to include rooms which retention policy is NULL
|
||||
in the returned set.
|
||||
|
||||
Returns:
|
||||
dict[str, dict]: The rooms within this range, along with their retention
|
||||
The rooms within this range, along with their retention
|
||||
policy. The key is "room_id", and maps to a dict describing the retention
|
||||
policy associated with this room ID. The keys for this nested dict are
|
||||
"min_lifetime" (int|None), and "max_lifetime" (int|None).
|
||||
|
@ -1431,9 +1417,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
|
||||
return rooms_dict
|
||||
|
||||
rooms = yield self.db.runInteraction(
|
||||
rooms = await self.db.runInteraction(
|
||||
"get_rooms_for_retention_period_in_range",
|
||||
get_rooms_for_retention_period_in_range_txn,
|
||||
)
|
||||
|
||||
defer.returnValue(rooms)
|
||||
return rooms
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Store a boolean value in the events table for whether the event should be counted in
|
||||
-- the unread_count property of sync responses.
|
||||
ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN;
|
|
@ -16,12 +16,12 @@
|
|||
import collections.abc
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import Iterable, Optional, Set
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
|
||||
|
@ -108,16 +108,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
create_event = await self.get_create_event_for_room(room_id)
|
||||
return create_event.content.get("room_version", "1")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_predecessor(self, room_id):
|
||||
async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
|
||||
"""Get the predecessor of an upgraded room if it exists.
|
||||
Otherwise return None.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id: The room ID.
|
||||
|
||||
Returns:
|
||||
Deferred[dict|None]: A dictionary containing the structure of the predecessor
|
||||
A dictionary containing the structure of the predecessor
|
||||
field from the room's create event. The structure is subject to other servers,
|
||||
but it is expected to be:
|
||||
* room_id (str): The room ID of the predecessor room
|
||||
|
@ -129,7 +128,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
NotFoundError if the given room is unknown
|
||||
"""
|
||||
# Retrieve the room's create event
|
||||
create_event = yield self.get_create_event_for_room(room_id)
|
||||
create_event = await self.get_create_event_for_room(room_id)
|
||||
|
||||
# Retrieve the predecessor key of the create event
|
||||
predecessor = create_event.content.get("predecessor", None)
|
||||
|
@ -140,20 +139,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return predecessor
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_create_event_for_room(self, room_id):
|
||||
async def get_create_event_for_room(self, room_id: str) -> EventBase:
|
||||
"""Get the create state event for a room.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id: The room ID.
|
||||
|
||||
Returns:
|
||||
Deferred[EventBase]: The room creation event.
|
||||
The room creation event.
|
||||
|
||||
Raises:
|
||||
NotFoundError if the room is unknown
|
||||
"""
|
||||
state_ids = yield self.get_current_state_ids(room_id)
|
||||
state_ids = await self.get_current_state_ids(room_id)
|
||||
create_id = state_ids.get((EventTypes.Create, ""))
|
||||
|
||||
# If we can't find the create event, assume we've hit a dead end
|
||||
|
@ -161,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
raise NotFoundError("Unknown room %s" % (room_id,))
|
||||
|
||||
# Retrieve the room's create event and return
|
||||
create_event = yield self.get_event(create_id)
|
||||
create_event = await self.get_event(create_id)
|
||||
return create_event
|
||||
|
||||
@cached(max_entries=100000, iterable=True)
|
||||
|
@ -237,18 +235,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_canonical_alias_for_room(self, room_id):
|
||||
async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
|
||||
"""Get canonical alias for room, if any
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id: The room ID
|
||||
|
||||
Returns:
|
||||
Deferred[str|None]: The canonical alias, if any
|
||||
The canonical alias, if any
|
||||
"""
|
||||
|
||||
state = yield self.get_filtered_current_state_ids(
|
||||
state = await self.get_filtered_current_state_ids(
|
||||
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
|
||||
)
|
||||
|
||||
|
@ -256,7 +253,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
if not event_id:
|
||||
return
|
||||
|
||||
event = yield self.get_event(event_id, allow_none=True)
|
||||
event = await self.get_event(event_id, allow_none=True)
|
||||
if not event:
|
||||
return
|
||||
|
||||
|
@ -292,19 +289,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return {row["event_id"]: row["state_group"] for row in rows}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_referenced_state_groups(self, state_groups):
|
||||
async def get_referenced_state_groups(
|
||||
self, state_groups: Iterable[int]
|
||||
) -> Set[int]:
|
||||
"""Check if the state groups are referenced by events.
|
||||
|
||||
Args:
|
||||
state_groups (Iterable[int])
|
||||
state_groups
|
||||
|
||||
Returns:
|
||||
Deferred[set[int]]: The subset of state groups that are
|
||||
referenced.
|
||||
The subset of state groups that are referenced.
|
||||
"""
|
||||
|
||||
rows = yield self.db.simple_select_many_batch(
|
||||
rows = await self.db.simple_select_many_batch(
|
||||
table="event_to_state_groups",
|
||||
column="state_group",
|
||||
iterable=state_groups,
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
|
||||
import logging
|
||||
from itertools import chain
|
||||
from typing import Tuple
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import DeferredLock
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
|
@ -97,13 +97,12 @@ class StatsStore(StateDeltasStore):
|
|||
"""
|
||||
return (ts // self.stats_bucket_size) * self.stats_bucket_size
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _populate_stats_process_users(self, progress, batch_size):
|
||||
async def _populate_stats_process_users(self, progress, batch_size):
|
||||
"""
|
||||
This is a background update which regenerates statistics for users.
|
||||
"""
|
||||
if not self.stats_enabled:
|
||||
yield self.db.updates._end_background_update("populate_stats_process_users")
|
||||
await self.db.updates._end_background_update("populate_stats_process_users")
|
||||
return 1
|
||||
|
||||
last_user_id = progress.get("last_user_id", "")
|
||||
|
@ -118,20 +117,20 @@ class StatsStore(StateDeltasStore):
|
|||
txn.execute(sql, (last_user_id, batch_size))
|
||||
return [r for r, in txn]
|
||||
|
||||
users_to_work_on = yield self.db.runInteraction(
|
||||
users_to_work_on = await self.db.runInteraction(
|
||||
"_populate_stats_process_users", _get_next_batch
|
||||
)
|
||||
|
||||
# No more rooms -- complete the transaction.
|
||||
if not users_to_work_on:
|
||||
yield self.db.updates._end_background_update("populate_stats_process_users")
|
||||
await self.db.updates._end_background_update("populate_stats_process_users")
|
||||
return 1
|
||||
|
||||
for user_id in users_to_work_on:
|
||||
yield self._calculate_and_set_initial_state_for_user(user_id)
|
||||
await self._calculate_and_set_initial_state_for_user(user_id)
|
||||
progress["last_user_id"] = user_id
|
||||
|
||||
yield self.db.runInteraction(
|
||||
await self.db.runInteraction(
|
||||
"populate_stats_process_users",
|
||||
self.db.updates._background_update_progress_txn,
|
||||
"populate_stats_process_users",
|
||||
|
@ -140,13 +139,12 @@ class StatsStore(StateDeltasStore):
|
|||
|
||||
return len(users_to_work_on)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _populate_stats_process_rooms(self, progress, batch_size):
|
||||
async def _populate_stats_process_rooms(self, progress, batch_size):
|
||||
"""
|
||||
This is a background update which regenerates statistics for rooms.
|
||||
"""
|
||||
if not self.stats_enabled:
|
||||
yield self.db.updates._end_background_update("populate_stats_process_rooms")
|
||||
await self.db.updates._end_background_update("populate_stats_process_rooms")
|
||||
return 1
|
||||
|
||||
last_room_id = progress.get("last_room_id", "")
|
||||
|
@ -161,20 +159,20 @@ class StatsStore(StateDeltasStore):
|
|||
txn.execute(sql, (last_room_id, batch_size))
|
||||
return [r for r, in txn]
|
||||
|
||||
rooms_to_work_on = yield self.db.runInteraction(
|
||||
rooms_to_work_on = await self.db.runInteraction(
|
||||
"populate_stats_rooms_get_batch", _get_next_batch
|
||||
)
|
||||
|
||||
# No more rooms -- complete the transaction.
|
||||
if not rooms_to_work_on:
|
||||
yield self.db.updates._end_background_update("populate_stats_process_rooms")
|
||||
await self.db.updates._end_background_update("populate_stats_process_rooms")
|
||||
return 1
|
||||
|
||||
for room_id in rooms_to_work_on:
|
||||
yield self._calculate_and_set_initial_state_for_room(room_id)
|
||||
await self._calculate_and_set_initial_state_for_room(room_id)
|
||||
progress["last_room_id"] = room_id
|
||||
|
||||
yield self.db.runInteraction(
|
||||
await self.db.runInteraction(
|
||||
"_populate_stats_process_rooms",
|
||||
self.db.updates._background_update_progress_txn,
|
||||
"populate_stats_process_rooms",
|
||||
|
@ -696,16 +694,16 @@ class StatsStore(StateDeltasStore):
|
|||
|
||||
return room_deltas, user_deltas
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _calculate_and_set_initial_state_for_room(self, room_id):
|
||||
async def _calculate_and_set_initial_state_for_room(
|
||||
self, room_id: str
|
||||
) -> Tuple[dict, dict, int]:
|
||||
"""Calculate and insert an entry into room_stats_current.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id: The room ID under calculation.
|
||||
|
||||
Returns:
|
||||
Deferred[tuple[dict, dict, int]]: A tuple of room state, membership
|
||||
counts and stream position.
|
||||
A tuple of room state, membership counts and stream position.
|
||||
"""
|
||||
|
||||
def _fetch_current_state_stats(txn):
|
||||
|
@ -767,11 +765,11 @@ class StatsStore(StateDeltasStore):
|
|||
current_state_events_count,
|
||||
users_in_room,
|
||||
pos,
|
||||
) = yield self.db.runInteraction(
|
||||
) = await self.db.runInteraction(
|
||||
"get_initial_state_for_room", _fetch_current_state_stats
|
||||
)
|
||||
|
||||
state_event_map = yield self.get_events(event_ids, get_prev_content=False)
|
||||
state_event_map = await self.get_events(event_ids, get_prev_content=False)
|
||||
|
||||
room_state = {
|
||||
"join_rules": None,
|
||||
|
@ -806,11 +804,11 @@ class StatsStore(StateDeltasStore):
|
|||
event.content.get("m.federate", True) is True
|
||||
)
|
||||
|
||||
yield self.update_room_state(room_id, room_state)
|
||||
await self.update_room_state(room_id, room_state)
|
||||
|
||||
local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)]
|
||||
|
||||
yield self.update_stats_delta(
|
||||
await self.update_stats_delta(
|
||||
ts=self.clock.time_msec(),
|
||||
stats_type="room",
|
||||
stats_id=room_id,
|
||||
|
@ -826,8 +824,7 @@ class StatsStore(StateDeltasStore):
|
|||
},
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _calculate_and_set_initial_state_for_user(self, user_id):
|
||||
async def _calculate_and_set_initial_state_for_user(self, user_id):
|
||||
def _calculate_and_set_initial_state_for_user_txn(txn):
|
||||
pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
|
||||
|
||||
|
@ -842,12 +839,12 @@ class StatsStore(StateDeltasStore):
|
|||
(count,) = txn.fetchone()
|
||||
return count, pos
|
||||
|
||||
joined_rooms, pos = yield self.db.runInteraction(
|
||||
joined_rooms, pos = await self.db.runInteraction(
|
||||
"calculate_and_set_initial_state_for_user",
|
||||
_calculate_and_set_initial_state_for_user_txn,
|
||||
)
|
||||
|
||||
yield self.update_stats_delta(
|
||||
await self.update_stats_delta(
|
||||
ts=self.clock.time_msec(),
|
||||
stats_type="user",
|
||||
stats_id=user_id,
|
||||
|
|
|
@ -139,10 +139,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
"get_state_group_delta", _get_state_group_delta_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_state_groups_from_groups(
|
||||
async def _get_state_groups_from_groups(
|
||||
self, groups: List[int], state_filter: StateFilter
|
||||
):
|
||||
) -> Dict[int, StateMap[str]]:
|
||||
"""Returns the state groups for a given set of groups from the
|
||||
database, filtering on types of state events.
|
||||
|
||||
|
@ -151,13 +150,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
state_filter: The state filter used to fetch state
|
||||
from the database.
|
||||
Returns:
|
||||
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
|
||||
Dict of state group to state map.
|
||||
"""
|
||||
results = {}
|
||||
|
||||
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
|
||||
for chunk in chunks:
|
||||
res = yield self.db.runInteraction(
|
||||
res = await self.db.runInteraction(
|
||||
"_get_state_groups_from_groups",
|
||||
self._get_state_groups_from_groups_txn,
|
||||
chunk,
|
||||
|
@ -206,10 +205,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
|
||||
return state_filter.filter_state(state_dict_ids), not missing_types
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_state_for_groups(
|
||||
async def _get_state_for_groups(
|
||||
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
|
||||
):
|
||||
) -> Dict[int, StateMap[str]]:
|
||||
"""Gets the state at each of a list of state groups, optionally
|
||||
filtering by type/state_key
|
||||
|
||||
|
@ -219,7 +217,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
state_filter: The state filter used to fetch state
|
||||
from the database.
|
||||
Returns:
|
||||
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
|
||||
Dict of state group to state map.
|
||||
"""
|
||||
|
||||
member_filter, non_member_filter = state_filter.get_member_split()
|
||||
|
@ -228,14 +226,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
(
|
||||
non_member_state,
|
||||
incomplete_groups_nm,
|
||||
) = yield self._get_state_for_groups_using_cache(
|
||||
) = self._get_state_for_groups_using_cache(
|
||||
groups, self._state_group_cache, state_filter=non_member_filter
|
||||
)
|
||||
|
||||
(
|
||||
member_state,
|
||||
incomplete_groups_m,
|
||||
) = yield self._get_state_for_groups_using_cache(
|
||||
(member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache(
|
||||
groups, self._state_group_members_cache, state_filter=member_filter
|
||||
)
|
||||
|
||||
|
@ -256,7 +251,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
# Help the cache hit ratio by expanding the filter a bit
|
||||
db_state_filter = state_filter.return_expanded()
|
||||
|
||||
group_to_state_dict = yield self._get_state_groups_from_groups(
|
||||
group_to_state_dict = await self._get_state_groups_from_groups(
|
||||
list(incomplete_groups), state_filter=db_state_filter
|
||||
)
|
||||
|
||||
|
@ -576,19 +571,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
((sg,) for sg in state_groups_to_delete),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_previous_state_groups(self, state_groups):
|
||||
async def get_previous_state_groups(
|
||||
self, state_groups: Iterable[int]
|
||||
) -> Dict[int, int]:
|
||||
"""Fetch the previous groups of the given state groups.
|
||||
|
||||
Args:
|
||||
state_groups (Iterable[int])
|
||||
state_groups
|
||||
|
||||
Returns:
|
||||
Deferred[dict[int, int]]: mapping from state group to previous
|
||||
state group.
|
||||
A mapping from state group to previous state group.
|
||||
"""
|
||||
|
||||
rows = yield self.db.simple_select_many_batch(
|
||||
rows = await self.db.simple_select_many_batch(
|
||||
table="state_group_edges",
|
||||
column="prev_state_group",
|
||||
iterable=state_groups,
|
||||
|
|
|
@ -49,11 +49,11 @@ from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3E
|
|||
from synapse.storage.types import Connection, Cursor
|
||||
from synapse.types import Collection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# python 3 does not have a maximum int value
|
||||
MAX_TXN_ID = 2 ** 63 - 1
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
sql_logger = logging.getLogger("synapse.storage.SQL")
|
||||
transaction_logger = logging.getLogger("synapse.storage.txn")
|
||||
perf_logger = logging.getLogger("synapse.storage.TIME")
|
||||
|
@ -233,7 +233,7 @@ class LoggingTransaction:
|
|||
try:
|
||||
return func(sql, *args)
|
||||
except Exception as e:
|
||||
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
|
||||
sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
|
||||
raise
|
||||
finally:
|
||||
secs = time.time() - start
|
||||
|
@ -419,7 +419,7 @@ class Database(object):
|
|||
except self.engine.module.OperationalError as e:
|
||||
# This can happen if the database disappears mid
|
||||
# transaction.
|
||||
logger.warning(
|
||||
transaction_logger.warning(
|
||||
"[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
|
||||
)
|
||||
if i < N:
|
||||
|
@ -427,18 +427,20 @@ class Database(object):
|
|||
try:
|
||||
conn.rollback()
|
||||
except self.engine.module.Error as e1:
|
||||
logger.warning("[TXN EROLL] {%s} %s", name, e1)
|
||||
transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1)
|
||||
continue
|
||||
raise
|
||||
except self.engine.module.DatabaseError as e:
|
||||
if self.engine.is_deadlock(e):
|
||||
logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
|
||||
transaction_logger.warning(
|
||||
"[TXN DEADLOCK] {%s} %d/%d", name, i, N
|
||||
)
|
||||
if i < N:
|
||||
i += 1
|
||||
try:
|
||||
conn.rollback()
|
||||
except self.engine.module.Error as e1:
|
||||
logger.warning(
|
||||
transaction_logger.warning(
|
||||
"[TXN EROLL] {%s} %s", name, e1,
|
||||
)
|
||||
continue
|
||||
|
@ -478,7 +480,7 @@ class Database(object):
|
|||
# [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
|
||||
cursor.close()
|
||||
except Exception as e:
|
||||
logger.debug("[TXN FAIL] {%s} %s", name, e)
|
||||
transaction_logger.debug("[TXN FAIL] {%s} %s", name, e)
|
||||
raise
|
||||
finally:
|
||||
end = monotonic_time()
|
||||
|
|
|
@ -25,7 +25,7 @@ from prometheus_client import Counter, Histogram
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
|
@ -192,12 +192,11 @@ class EventsPersistenceStorage(object):
|
|||
self._event_persist_queue = _EventPeristenceQueue()
|
||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def persist_events(
|
||||
async def persist_events(
|
||||
self,
|
||||
events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
backfilled: bool = False,
|
||||
):
|
||||
) -> int:
|
||||
"""
|
||||
Write events to the database
|
||||
Args:
|
||||
|
@ -207,7 +206,7 @@ class EventsPersistenceStorage(object):
|
|||
which might update the current state etc.
|
||||
|
||||
Returns:
|
||||
Deferred[int]: the stream ordering of the latest persisted event
|
||||
the stream ordering of the latest persisted event
|
||||
"""
|
||||
partitioned = {}
|
||||
for event, ctx in events_and_contexts:
|
||||
|
@ -223,22 +222,19 @@ class EventsPersistenceStorage(object):
|
|||
for room_id in partitioned:
|
||||
self._maybe_start_persisting(room_id)
|
||||
|
||||
yield make_deferred_yieldable(
|
||||
await make_deferred_yieldable(
|
||||
defer.gatherResults(deferreds, consumeErrors=True)
|
||||
)
|
||||
|
||||
max_persisted_id = yield self.main_store.get_current_events_token()
|
||||
return self.main_store.get_current_events_token()
|
||||
|
||||
return max_persisted_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def persist_event(
|
||||
self, event: FrozenEvent, context: EventContext, backfilled: bool = False
|
||||
):
|
||||
async def persist_event(
|
||||
self, event: EventBase, context: EventContext, backfilled: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Returns:
|
||||
Deferred[Tuple[int, int]]: the stream ordering of ``event``,
|
||||
and the stream ordering of the latest persisted event
|
||||
The stream ordering of `event`, and the stream ordering of the
|
||||
latest persisted event
|
||||
"""
|
||||
deferred = self._event_persist_queue.add_to_queue(
|
||||
event.room_id, [(event, context)], backfilled=backfilled
|
||||
|
@ -246,9 +242,9 @@ class EventsPersistenceStorage(object):
|
|||
|
||||
self._maybe_start_persisting(event.room_id)
|
||||
|
||||
yield make_deferred_yieldable(deferred)
|
||||
await make_deferred_yieldable(deferred)
|
||||
|
||||
max_persisted_id = yield self.main_store.get_current_events_token()
|
||||
max_persisted_id = self.main_store.get_current_events_token()
|
||||
return (event.internal_metadata.stream_ordering, max_persisted_id)
|
||||
|
||||
def _maybe_start_persisting(self, room_id: str):
|
||||
|
@ -262,7 +258,7 @@ class EventsPersistenceStorage(object):
|
|||
|
||||
async def _persist_events(
|
||||
self,
|
||||
events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
backfilled: bool = False,
|
||||
):
|
||||
"""Calculates the change to current state and forward extremities, and
|
||||
|
@ -439,7 +435,7 @@ class EventsPersistenceStorage(object):
|
|||
async def _calculate_new_extremities(
|
||||
self,
|
||||
room_id: str,
|
||||
event_contexts: List[Tuple[FrozenEvent, EventContext]],
|
||||
event_contexts: List[Tuple[EventBase, EventContext]],
|
||||
latest_event_ids: List[str],
|
||||
):
|
||||
"""Calculates the new forward extremities for a room given events to
|
||||
|
@ -497,7 +493,7 @@ class EventsPersistenceStorage(object):
|
|||
async def _get_new_state_after_events(
|
||||
self,
|
||||
room_id: str,
|
||||
events_context: List[Tuple[FrozenEvent, EventContext]],
|
||||
events_context: List[Tuple[EventBase, EventContext]],
|
||||
old_latest_event_ids: Iterable[str],
|
||||
new_latest_event_ids: Iterable[str],
|
||||
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
|
||||
|
@ -683,7 +679,7 @@ class EventsPersistenceStorage(object):
|
|||
async def _is_server_still_joined(
|
||||
self,
|
||||
room_id: str,
|
||||
ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]],
|
||||
ev_ctx_rm: List[Tuple[EventBase, EventContext]],
|
||||
delta: DeltaState,
|
||||
current_state: Optional[StateMap[str]],
|
||||
potentially_left_users: Set[str],
|
||||
|
|
|
@ -15,8 +15,7 @@
|
|||
|
||||
import itertools
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -28,49 +27,48 @@ class PurgeEventsStorage(object):
|
|||
def __init__(self, hs, stores):
|
||||
self.stores = stores
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def purge_room(self, room_id: str):
|
||||
async def purge_room(self, room_id: str):
|
||||
"""Deletes all record of a room
|
||||
"""
|
||||
|
||||
state_groups_to_delete = yield self.stores.main.purge_room(room_id)
|
||||
yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
|
||||
state_groups_to_delete = await self.stores.main.purge_room(room_id)
|
||||
await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def purge_history(self, room_id, token, delete_local_events):
|
||||
async def purge_history(
|
||||
self, room_id: str, token: str, delete_local_events: bool
|
||||
) -> None:
|
||||
"""Deletes room history before a certain point
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
room_id: The room ID
|
||||
|
||||
token (str): A topological token to delete events before
|
||||
token: A topological token to delete events before
|
||||
|
||||
delete_local_events (bool):
|
||||
delete_local_events:
|
||||
if True, we will delete local events as well as remote ones
|
||||
(instead of just marking them as outliers and deleting their
|
||||
state groups).
|
||||
"""
|
||||
state_groups = yield self.stores.main.purge_history(
|
||||
state_groups = await self.stores.main.purge_history(
|
||||
room_id, token, delete_local_events
|
||||
)
|
||||
|
||||
logger.info("[purge] finding state groups that can be deleted")
|
||||
|
||||
sg_to_delete = yield self._find_unreferenced_groups(state_groups)
|
||||
sg_to_delete = await self._find_unreferenced_groups(state_groups)
|
||||
|
||||
yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
|
||||
await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _find_unreferenced_groups(self, state_groups):
|
||||
async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]:
|
||||
"""Used when purging history to figure out which state groups can be
|
||||
deleted.
|
||||
|
||||
Args:
|
||||
state_groups (set[int]): Set of state groups referenced by events
|
||||
state_groups: Set of state groups referenced by events
|
||||
that are going to be deleted.
|
||||
|
||||
Returns:
|
||||
Deferred[set[int]] The set of state groups that can be deleted.
|
||||
The set of state groups that can be deleted.
|
||||
"""
|
||||
# Graph of state group -> previous group
|
||||
graph = {}
|
||||
|
@ -93,7 +91,7 @@ class PurgeEventsStorage(object):
|
|||
current_search = set(itertools.islice(next_to_search, 100))
|
||||
next_to_search -= current_search
|
||||
|
||||
referenced = yield self.stores.main.get_referenced_state_groups(
|
||||
referenced = await self.stores.main.get_referenced_state_groups(
|
||||
current_search
|
||||
)
|
||||
referenced_groups |= referenced
|
||||
|
@ -102,7 +100,7 @@ class PurgeEventsStorage(object):
|
|||
# groups that are referenced.
|
||||
current_search -= referenced
|
||||
|
||||
edges = yield self.stores.state.get_previous_state_groups(current_search)
|
||||
edges = await self.stores.state.get_previous_state_groups(current_search)
|
||||
|
||||
prevs = set(edges.values())
|
||||
# We don't bother re-handling groups we've already seen
|
||||
|
|
|
@ -14,13 +14,12 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Iterable, List, TypeVar
|
||||
from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import StateMap
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -34,16 +33,16 @@ class StateFilter(object):
|
|||
"""A filter used when querying for state.
|
||||
|
||||
Attributes:
|
||||
types (dict[str, set[str]|None]): Map from type to set of state keys (or
|
||||
None). This specifies which state_keys for the given type to fetch
|
||||
from the DB. If None then all events with that type are fetched. If
|
||||
the set is empty then no events with that type are fetched.
|
||||
include_others (bool): Whether to fetch events with types that do not
|
||||
types: Map from type to set of state keys (or None). This specifies
|
||||
which state_keys for the given type to fetch from the DB. If None
|
||||
then all events with that type are fetched. If the set is empty
|
||||
then no events with that type are fetched.
|
||||
include_others: Whether to fetch events with types that do not
|
||||
appear in `types`.
|
||||
"""
|
||||
|
||||
types = attr.ib()
|
||||
include_others = attr.ib(default=False)
|
||||
types = attr.ib(type=Dict[str, Optional[Set[str]]])
|
||||
include_others = attr.ib(default=False, type=bool)
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
# If `include_others` is set we canonicalise the filter by removing
|
||||
|
@ -52,36 +51,35 @@ class StateFilter(object):
|
|||
self.types = {k: v for k, v in self.types.items() if v is not None}
|
||||
|
||||
@staticmethod
|
||||
def all():
|
||||
def all() -> "StateFilter":
|
||||
"""Creates a filter that fetches everything.
|
||||
|
||||
Returns:
|
||||
StateFilter
|
||||
The new state filter.
|
||||
"""
|
||||
return StateFilter(types={}, include_others=True)
|
||||
|
||||
@staticmethod
|
||||
def none():
|
||||
def none() -> "StateFilter":
|
||||
"""Creates a filter that fetches nothing.
|
||||
|
||||
Returns:
|
||||
StateFilter
|
||||
The new state filter.
|
||||
"""
|
||||
return StateFilter(types={}, include_others=False)
|
||||
|
||||
@staticmethod
|
||||
def from_types(types):
|
||||
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
|
||||
"""Creates a filter that only fetches the given types
|
||||
|
||||
Args:
|
||||
types (Iterable[tuple[str, str|None]]): A list of type and state
|
||||
keys to fetch. A state_key of None fetches everything for
|
||||
that type
|
||||
types: A list of type and state keys to fetch. A state_key of None
|
||||
fetches everything for that type
|
||||
|
||||
Returns:
|
||||
StateFilter
|
||||
The new state filter.
|
||||
"""
|
||||
type_dict = {}
|
||||
type_dict = {} # type: Dict[str, Optional[Set[str]]]
|
||||
for typ, s in types:
|
||||
if typ in type_dict:
|
||||
if type_dict[typ] is None:
|
||||
|
@ -91,24 +89,24 @@ class StateFilter(object):
|
|||
type_dict[typ] = None
|
||||
continue
|
||||
|
||||
type_dict.setdefault(typ, set()).add(s)
|
||||
type_dict.setdefault(typ, set()).add(s) # type: ignore
|
||||
|
||||
return StateFilter(types=type_dict)
|
||||
|
||||
@staticmethod
|
||||
def from_lazy_load_member_list(members):
|
||||
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
|
||||
"""Creates a filter that returns all non-member events, plus the member
|
||||
events for the given users
|
||||
|
||||
Args:
|
||||
members (iterable[str]): Set of user IDs
|
||||
members: Set of user IDs
|
||||
|
||||
Returns:
|
||||
StateFilter
|
||||
The new state filter
|
||||
"""
|
||||
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
|
||||
|
||||
def return_expanded(self):
|
||||
def return_expanded(self) -> "StateFilter":
|
||||
"""Creates a new StateFilter where type wild cards have been removed
|
||||
(except for memberships). The returned filter is a superset of the
|
||||
current one, i.e. anything that passes the current filter will pass
|
||||
|
@ -130,7 +128,7 @@ class StateFilter(object):
|
|||
return all non-member events
|
||||
|
||||
Returns:
|
||||
StateFilter
|
||||
The new state filter.
|
||||
"""
|
||||
|
||||
if self.is_full():
|
||||
|
@ -167,7 +165,7 @@ class StateFilter(object):
|
|||
include_others=True,
|
||||
)
|
||||
|
||||
def make_sql_filter_clause(self):
|
||||
def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
|
||||
"""Converts the filter to an SQL clause.
|
||||
|
||||
For example:
|
||||
|
@ -179,13 +177,12 @@ class StateFilter(object):
|
|||
|
||||
|
||||
Returns:
|
||||
tuple[str, list]: The SQL string (may be empty) and arguments. An
|
||||
empty SQL string is returned when the filter matches everything
|
||||
(i.e. is "full").
|
||||
The SQL string (may be empty) and arguments. An empty SQL string is
|
||||
returned when the filter matches everything (i.e. is "full").
|
||||
"""
|
||||
|
||||
where_clause = ""
|
||||
where_args = []
|
||||
where_args = [] # type: List[str]
|
||||
|
||||
if self.is_full():
|
||||
return where_clause, where_args
|
||||
|
@ -221,7 +218,7 @@ class StateFilter(object):
|
|||
|
||||
return where_clause, where_args
|
||||
|
||||
def max_entries_returned(self):
|
||||
def max_entries_returned(self) -> Optional[int]:
|
||||
"""Returns the maximum number of entries this filter will return if
|
||||
known, otherwise returns None.
|
||||
|
||||
|
@ -260,33 +257,33 @@ class StateFilter(object):
|
|||
|
||||
return filtered_state
|
||||
|
||||
def is_full(self):
|
||||
def is_full(self) -> bool:
|
||||
"""Whether this filter fetches everything or not
|
||||
|
||||
Returns:
|
||||
bool
|
||||
True if the filter fetches everything.
|
||||
"""
|
||||
return self.include_others and not self.types
|
||||
|
||||
def has_wildcards(self):
|
||||
def has_wildcards(self) -> bool:
|
||||
"""Whether the filter includes wildcards or is attempting to fetch
|
||||
specific state.
|
||||
|
||||
Returns:
|
||||
bool
|
||||
True if the filter includes wildcards.
|
||||
"""
|
||||
|
||||
return self.include_others or any(
|
||||
state_keys is None for state_keys in self.types.values()
|
||||
)
|
||||
|
||||
def concrete_types(self):
|
||||
def concrete_types(self) -> List[Tuple[str, str]]:
|
||||
"""Returns a list of concrete type/state_keys (i.e. not None) that
|
||||
will be fetched. This will be a complete list if `has_wildcards`
|
||||
returns False, but otherwise will be a subset (or even empty).
|
||||
|
||||
Returns:
|
||||
list[tuple[str,str]]
|
||||
A list of type/state_keys tuples.
|
||||
"""
|
||||
return [
|
||||
(t, s)
|
||||
|
@ -295,7 +292,7 @@ class StateFilter(object):
|
|||
for s in state_keys
|
||||
]
|
||||
|
||||
def get_member_split(self):
|
||||
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
|
||||
"""Return the filter split into two: one which assumes it's exclusively
|
||||
matching against member state, and one which assumes it's matching
|
||||
against non member state.
|
||||
|
@ -307,7 +304,7 @@ class StateFilter(object):
|
|||
state caches).
|
||||
|
||||
Returns:
|
||||
tuple[StateFilter, StateFilter]: The member and non member filters
|
||||
The member and non member filters
|
||||
"""
|
||||
|
||||
if EventTypes.Member in self.types:
|
||||
|
@ -340,6 +337,9 @@ class StateGroupStorage(object):
|
|||
"""Given a state group try to return a previous group and a delta between
|
||||
the old and the new.
|
||||
|
||||
Args:
|
||||
state_group: The state group used to retrieve state deltas.
|
||||
|
||||
Returns:
|
||||
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
|
||||
(prev_group, delta_ids)
|
||||
|
@ -347,55 +347,59 @@ class StateGroupStorage(object):
|
|||
|
||||
return self.stores.state.get_state_group_delta(state_group)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_groups_ids(self, _room_id, event_ids):
|
||||
async def get_state_groups_ids(
|
||||
self, _room_id: str, event_ids: Iterable[str]
|
||||
) -> Dict[int, StateMap[str]]:
|
||||
"""Get the event IDs of all the state for the state groups for the given events
|
||||
|
||||
Args:
|
||||
_room_id (str): id of the room for these events
|
||||
event_ids (iterable[str]): ids of the events
|
||||
_room_id: id of the room for these events
|
||||
event_ids: ids of the events
|
||||
|
||||
Returns:
|
||||
Deferred[dict[int, StateMap[str]]]:
|
||||
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||
"""
|
||||
if not event_ids:
|
||||
return {}
|
||||
|
||||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
|
||||
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = yield self.stores.state._get_state_for_groups(groups)
|
||||
group_to_state = await self.stores.state._get_state_for_groups(groups)
|
||||
|
||||
return group_to_state
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_ids_for_group(self, state_group):
|
||||
async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
|
||||
"""Get the event IDs of all the state in the given state group
|
||||
|
||||
Args:
|
||||
state_group (int)
|
||||
state_group: A state group for which we want to get the state IDs.
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
|
||||
Resolves to a map of (type, state_key) -> event_id
|
||||
"""
|
||||
group_to_state = yield self._get_state_for_groups((state_group,))
|
||||
group_to_state = await self._get_state_for_groups((state_group,))
|
||||
|
||||
return group_to_state[state_group]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_groups(self, room_id, event_ids):
|
||||
async def get_state_groups(
|
||||
self, room_id: str, event_ids: Iterable[str]
|
||||
) -> Dict[int, List[EventBase]]:
|
||||
""" Get the state groups for the given list of event_ids
|
||||
|
||||
Args:
|
||||
room_id: ID of the room for these events.
|
||||
event_ids: The event IDs to retrieve state for.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[int, list[EventBase]]]:
|
||||
dict of state_group_id -> list of state events.
|
||||
"""
|
||||
if not event_ids:
|
||||
return {}
|
||||
|
||||
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
|
||||
group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
|
||||
|
||||
state_event_map = yield self.stores.main.get_events(
|
||||
state_event_map = await self.stores.main.get_events(
|
||||
[
|
||||
ev_id
|
||||
for group_ids in group_to_ids.values()
|
||||
|
@ -415,7 +419,7 @@ class StateGroupStorage(object):
|
|||
|
||||
def _get_state_groups_from_groups(
|
||||
self, groups: List[int], state_filter: StateFilter
|
||||
):
|
||||
) -> Awaitable[Dict[int, StateMap[str]]]:
|
||||
"""Returns the state groups for a given set of groups, filtering on
|
||||
types of state events.
|
||||
|
||||
|
@ -423,31 +427,34 @@ class StateGroupStorage(object):
|
|||
groups: list of state group IDs to query
|
||||
state_filter: The state filter used to fetch state
|
||||
from the database.
|
||||
|
||||
Returns:
|
||||
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
|
||||
Dict of state group to state map.
|
||||
"""
|
||||
|
||||
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
|
||||
async def get_state_for_events(
|
||||
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
|
||||
):
|
||||
"""Given a list of event_ids and type tuples, return a list of state
|
||||
dicts for each event.
|
||||
|
||||
Args:
|
||||
event_ids (list[string])
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
from the database.
|
||||
event_ids: The events to fetch the state of.
|
||||
state_filter: The state filter used to fetch state.
|
||||
|
||||
Returns:
|
||||
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
|
||||
A dict of (event_id) -> (type, state_key) -> [state_events]
|
||||
"""
|
||||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
|
||||
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = yield self.stores.state._get_state_for_groups(
|
||||
group_to_state = await self.stores.state._get_state_for_groups(
|
||||
groups, state_filter
|
||||
)
|
||||
|
||||
state_event_map = yield self.stores.main.get_events(
|
||||
state_event_map = await self.stores.main.get_events(
|
||||
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
|
||||
get_prev_content=False,
|
||||
)
|
||||
|
@ -463,24 +470,24 @@ class StateGroupStorage(object):
|
|||
|
||||
return {event: event_to_state[event] for event in event_ids}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
|
||||
async def get_state_ids_for_events(
|
||||
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
|
||||
):
|
||||
"""
|
||||
Get the state dicts corresponding to a list of events, containing the event_ids
|
||||
of the state events (as opposed to the events themselves)
|
||||
|
||||
Args:
|
||||
event_ids(list(str)): events whose state should be returned
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
from the database.
|
||||
event_ids: events whose state should be returned
|
||||
state_filter: The state filter used to fetch state from the database.
|
||||
|
||||
Returns:
|
||||
A deferred dict from event_id -> (type, state_key) -> event_id
|
||||
A dict from event_id -> (type, state_key) -> event_id
|
||||
"""
|
||||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
|
||||
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = yield self.stores.state._get_state_for_groups(
|
||||
group_to_state = await self.stores.state._get_state_for_groups(
|
||||
groups, state_filter
|
||||
)
|
||||
|
||||
|
@ -491,67 +498,72 @@ class StateGroupStorage(object):
|
|||
|
||||
return {event: event_to_state[event] for event in event_ids}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
|
||||
async def get_state_for_event(
|
||||
self, event_id: str, state_filter: StateFilter = StateFilter.all()
|
||||
):
|
||||
"""
|
||||
Get the state dict corresponding to a particular event
|
||||
|
||||
Args:
|
||||
event_id(str): event whose state should be returned
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
from the database.
|
||||
event_id: event whose state should be returned
|
||||
state_filter: The state filter used to fetch state from the database.
|
||||
|
||||
Returns:
|
||||
A deferred dict from (type, state_key) -> state_event
|
||||
A dict from (type, state_key) -> state_event
|
||||
"""
|
||||
state_map = yield self.get_state_for_events([event_id], state_filter)
|
||||
state_map = await self.get_state_for_events([event_id], state_filter)
|
||||
return state_map[event_id]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
|
||||
async def get_state_ids_for_event(
|
||||
self, event_id: str, state_filter: StateFilter = StateFilter.all()
|
||||
):
|
||||
"""
|
||||
Get the state dict corresponding to a particular event
|
||||
|
||||
Args:
|
||||
event_id(str): event whose state should be returned
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
from the database.
|
||||
event_id: event whose state should be returned
|
||||
state_filter: The state filter used to fetch state from the database.
|
||||
|
||||
Returns:
|
||||
A deferred dict from (type, state_key) -> state_event
|
||||
"""
|
||||
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
|
||||
state_map = await self.get_state_ids_for_events([event_id], state_filter)
|
||||
return state_map[event_id]
|
||||
|
||||
def _get_state_for_groups(
|
||||
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
|
||||
):
|
||||
) -> Awaitable[Dict[int, StateMap[str]]]:
|
||||
"""Gets the state at each of a list of state groups, optionally
|
||||
filtering by type/state_key
|
||||
|
||||
Args:
|
||||
groups (iterable[int]): list of state groups for which we want
|
||||
to get the state.
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
groups: list of state groups for which we want to get the state.
|
||||
state_filter: The state filter used to fetch state.
|
||||
from the database.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
|
||||
Dict of state group to state map.
|
||||
"""
|
||||
return self.stores.state._get_state_for_groups(groups, state_filter)
|
||||
|
||||
def store_state_group(
|
||||
self, event_id, room_id, prev_group, delta_ids, current_state_ids
|
||||
self,
|
||||
event_id: str,
|
||||
room_id: str,
|
||||
prev_group: Optional[int],
|
||||
delta_ids: Optional[dict],
|
||||
current_state_ids: dict,
|
||||
):
|
||||
"""Store a new set of state, returning a newly assigned state group.
|
||||
|
||||
Args:
|
||||
event_id (str): The event ID for which the state was calculated
|
||||
room_id (str)
|
||||
prev_group (int|None): A previous state group for the room, optional.
|
||||
delta_ids (dict|None): The delta between state at `prev_group` and
|
||||
event_id: The event ID for which the state was calculated.
|
||||
room_id: ID of the room for which the state was calculated.
|
||||
prev_group: A previous state group for the room, optional.
|
||||
delta_ids: The delta between state at `prev_group` and
|
||||
`current_state_ids`, if `prev_group` was given. Same format as
|
||||
`current_state_ids`.
|
||||
current_state_ids (dict): The state to store. Map of (type, state_key)
|
||||
current_state_ids: The state to store. Map of (type, state_key)
|
||||
to event_id.
|
||||
|
||||
Returns:
|
||||
|
|
|
@ -16,8 +16,6 @@
|
|||
import logging
|
||||
import operator
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.storage import Storage
|
||||
|
@ -39,8 +37,7 @@ MEMBERSHIP_PRIORITY = (
|
|||
)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def filter_events_for_client(
|
||||
async def filter_events_for_client(
|
||||
storage: Storage,
|
||||
user_id,
|
||||
events,
|
||||
|
@ -67,19 +64,19 @@ def filter_events_for_client(
|
|||
also be called to check whether a user can see the state at a given point.
|
||||
|
||||
Returns:
|
||||
Deferred[list[synapse.events.EventBase]]
|
||||
list[synapse.events.EventBase]
|
||||
"""
|
||||
# Filter out events that have been soft failed so that we don't relay them
|
||||
# to clients.
|
||||
events = [e for e in events if not e.internal_metadata.is_soft_failed()]
|
||||
|
||||
types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
|
||||
event_id_to_state = yield storage.state.get_state_for_events(
|
||||
event_id_to_state = await storage.state.get_state_for_events(
|
||||
frozenset(e.event_id for e in events),
|
||||
state_filter=StateFilter.from_types(types),
|
||||
)
|
||||
|
||||
ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user(
|
||||
ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
|
||||
"m.ignored_user_list", user_id
|
||||
)
|
||||
|
||||
|
@ -90,7 +87,7 @@ def filter_events_for_client(
|
|||
else []
|
||||
)
|
||||
|
||||
erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
|
||||
erased_senders = await storage.main.are_users_erased((e.sender for e in events))
|
||||
|
||||
if filter_send_to_client:
|
||||
room_ids = {e.room_id for e in events}
|
||||
|
@ -99,7 +96,7 @@ def filter_events_for_client(
|
|||
for room_id in room_ids:
|
||||
retention_policies[
|
||||
room_id
|
||||
] = yield storage.main.get_retention_policy_for_room(room_id)
|
||||
] = await storage.main.get_retention_policy_for_room(room_id)
|
||||
|
||||
def allowed(event):
|
||||
"""
|
||||
|
@ -254,8 +251,7 @@ def filter_events_for_client(
|
|||
return list(filtered_events)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def filter_events_for_server(
|
||||
async def filter_events_for_server(
|
||||
storage: Storage,
|
||||
server_name,
|
||||
events,
|
||||
|
@ -277,7 +273,7 @@ def filter_events_for_server(
|
|||
backfill or not.
|
||||
|
||||
Returns
|
||||
Deferred[list[FrozenEvent]]
|
||||
list[FrozenEvent]
|
||||
"""
|
||||
|
||||
def is_sender_erased(event, erased_senders):
|
||||
|
@ -321,7 +317,7 @@ def filter_events_for_server(
|
|||
# Lets check to see if all the events have a history visibility
|
||||
# of "shared" or "world_readable". If that's the case then we don't
|
||||
# need to check membership (as we know the server is in the room).
|
||||
event_to_state_ids = yield storage.state.get_state_ids_for_events(
|
||||
event_to_state_ids = await storage.state.get_state_ids_for_events(
|
||||
frozenset(e.event_id for e in events),
|
||||
state_filter=StateFilter.from_types(
|
||||
types=((EventTypes.RoomHistoryVisibility, ""),)
|
||||
|
@ -339,14 +335,14 @@ def filter_events_for_server(
|
|||
if not visibility_ids:
|
||||
all_open = True
|
||||
else:
|
||||
event_map = yield storage.main.get_events(visibility_ids)
|
||||
event_map = await storage.main.get_events(visibility_ids)
|
||||
all_open = all(
|
||||
e.content.get("history_visibility") in (None, "shared", "world_readable")
|
||||
for e in event_map.values()
|
||||
)
|
||||
|
||||
if not check_history_visibility_only:
|
||||
erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
|
||||
erased_senders = await storage.main.are_users_erased((e.sender for e in events))
|
||||
else:
|
||||
# We don't want to check whether users are erased, which is equivalent
|
||||
# to no users having been erased.
|
||||
|
@ -375,7 +371,7 @@ def filter_events_for_server(
|
|||
|
||||
# first, for each event we're wanting to return, get the event_ids
|
||||
# of the history vis and membership state at those events.
|
||||
event_to_state_ids = yield storage.state.get_state_ids_for_events(
|
||||
event_to_state_ids = await storage.state.get_state_ids_for_events(
|
||||
frozenset(e.event_id for e in events),
|
||||
state_filter=StateFilter.from_types(
|
||||
types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None))
|
||||
|
@ -405,7 +401,7 @@ def filter_events_for_server(
|
|||
return False
|
||||
return state_key[idx + 1 :] == server_name
|
||||
|
||||
event_map = yield storage.main.get_events(
|
||||
event_map = await storage.main.get_events(
|
||||
[e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])]
|
||||
)
|
||||
|
||||
|
|
|
@ -50,13 +50,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
def test_regex_user_id_prefix_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||
self.event.sender = "@irc_foobar:matrix.org"
|
||||
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||
self.assertTrue(
|
||||
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_user_id_prefix_no_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||
self.event.sender = "@someone_else:matrix.org"
|
||||
self.assertFalse((yield self.service.is_interested(self.event)))
|
||||
self.assertFalse(
|
||||
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_room_member_is_checked(self):
|
||||
|
@ -64,7 +68,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
self.event.sender = "@someone_else:matrix.org"
|
||||
self.event.type = "m.room.member"
|
||||
self.event.state_key = "@irc_foobar:matrix.org"
|
||||
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||
self.assertTrue(
|
||||
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_room_id_match(self):
|
||||
|
@ -72,7 +78,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
_regex("!some_prefix.*some_suffix:matrix.org")
|
||||
)
|
||||
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
|
||||
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||
self.assertTrue(
|
||||
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_room_id_no_match(self):
|
||||
|
@ -80,19 +88,26 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
_regex("!some_prefix.*some_suffix:matrix.org")
|
||||
)
|
||||
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
|
||||
self.assertFalse((yield self.service.is_interested(self.event)))
|
||||
self.assertFalse(
|
||||
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_alias_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
_regex("#irc_.*:matrix.org")
|
||||
)
|
||||
self.store.get_aliases_for_room.return_value = [
|
||||
"#irc_foobar:matrix.org",
|
||||
"#athing:matrix.org",
|
||||
]
|
||||
self.store.get_users_in_room.return_value = []
|
||||
self.assertTrue((yield self.service.is_interested(self.event, self.store)))
|
||||
self.store.get_aliases_for_room.return_value = defer.succeed(
|
||||
["#irc_foobar:matrix.org", "#athing:matrix.org"]
|
||||
)
|
||||
self.store.get_users_in_room.return_value = defer.succeed([])
|
||||
self.assertTrue(
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.service.is_interested(self.event, self.store)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def test_non_exclusive_alias(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
|
@ -135,12 +150,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
_regex("#irc_.*:matrix.org")
|
||||
)
|
||||
self.store.get_aliases_for_room.return_value = [
|
||||
"#xmpp_foobar:matrix.org",
|
||||
"#athing:matrix.org",
|
||||
]
|
||||
self.store.get_users_in_room.return_value = []
|
||||
self.assertFalse((yield self.service.is_interested(self.event, self.store)))
|
||||
self.store.get_aliases_for_room.return_value = defer.succeed(
|
||||
["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
|
||||
)
|
||||
self.store.get_users_in_room.return_value = defer.succeed([])
|
||||
self.assertFalse(
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.service.is_interested(self.event, self.store)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_multiple_matches(self):
|
||||
|
@ -149,9 +169,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
)
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||
self.event.sender = "@irc_foobar:matrix.org"
|
||||
self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
|
||||
self.store.get_users_in_room.return_value = []
|
||||
self.assertTrue((yield self.service.is_interested(self.event, self.store)))
|
||||
self.store.get_aliases_for_room.return_value = defer.succeed(
|
||||
["#irc_barfoo:matrix.org"]
|
||||
)
|
||||
self.store.get_users_in_room.return_value = defer.succeed([])
|
||||
self.assertTrue(
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.service.is_interested(self.event, self.store)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_interested_in_self(self):
|
||||
|
@ -161,19 +189,24 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
self.event.type = "m.room.member"
|
||||
self.event.content = {"membership": "invite"}
|
||||
self.event.state_key = self.service.sender
|
||||
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||
self.assertTrue(
|
||||
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_member_list_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||
self.store.get_users_in_room.return_value = [
|
||||
"@alice:here",
|
||||
"@irc_fo:here", # AS user
|
||||
"@bob:here",
|
||||
]
|
||||
self.store.get_aliases_for_room.return_value = []
|
||||
# Note that @irc_fo:here is the AS user.
|
||||
self.store.get_users_in_room.return_value = defer.succeed(
|
||||
["@alice:here", "@irc_fo:here", "@bob:here"]
|
||||
)
|
||||
self.store.get_aliases_for_room.return_value = defer.succeed([])
|
||||
|
||||
self.event.sender = "@xmpp_foobar:matrix.org"
|
||||
self.assertTrue(
|
||||
(yield self.service.is_interested(event=self.event, store=self.store))
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.service.is_interested(event=self.event, store=self.store)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
|
@ -25,6 +25,7 @@ from synapse.appservice.scheduler import (
|
|||
from synapse.logging.context import make_deferred_yieldable
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import make_awaitable
|
||||
|
||||
from ..utils import MockClock
|
||||
|
||||
|
@ -52,11 +53,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
|||
self.store.get_appservice_state = Mock(
|
||||
return_value=defer.succeed(ApplicationServiceState.UP)
|
||||
)
|
||||
txn.send = Mock(return_value=defer.succeed(True))
|
||||
txn.send = Mock(return_value=make_awaitable(True))
|
||||
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
|
||||
|
||||
# actual call
|
||||
self.txnctrl.send(service, events)
|
||||
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
|
||||
|
||||
self.store.create_appservice_txn.assert_called_once_with(
|
||||
service=service, events=events # txn made and saved
|
||||
|
@ -77,7 +78,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
|||
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
|
||||
|
||||
# actual call
|
||||
self.txnctrl.send(service, events)
|
||||
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
|
||||
|
||||
self.store.create_appservice_txn.assert_called_once_with(
|
||||
service=service, events=events # txn made and saved
|
||||
|
@ -98,11 +99,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
|||
return_value=defer.succeed(ApplicationServiceState.UP)
|
||||
)
|
||||
self.store.set_appservice_state = Mock(return_value=defer.succeed(True))
|
||||
txn.send = Mock(return_value=defer.succeed(False)) # fails to send
|
||||
txn.send = Mock(return_value=make_awaitable(False)) # fails to send
|
||||
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
|
||||
|
||||
# actual call
|
||||
self.txnctrl.send(service, events)
|
||||
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
|
||||
|
||||
self.store.create_appservice_txn.assert_called_once_with(
|
||||
service=service, events=events
|
||||
|
@ -144,7 +145,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
|||
self.recoverer.recover()
|
||||
# shouldn't have called anything prior to waiting for exp backoff
|
||||
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
|
||||
txn.send = Mock(return_value=True)
|
||||
txn.send = Mock(return_value=make_awaitable(True))
|
||||
txn.complete.return_value = make_awaitable(None)
|
||||
# wait for exp backoff
|
||||
self.clock.advance_time(2)
|
||||
self.assertEquals(1, txn.send.call_count)
|
||||
|
@ -169,7 +171,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
|||
|
||||
self.recoverer.recover()
|
||||
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
|
||||
txn.send = Mock(return_value=False)
|
||||
txn.send = Mock(return_value=make_awaitable(False))
|
||||
txn.complete.return_value = make_awaitable(None)
|
||||
self.clock.advance_time(2)
|
||||
self.assertEquals(1, txn.send.call_count)
|
||||
self.assertEquals(0, txn.complete.call_count)
|
||||
|
@ -182,7 +185,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
|||
self.assertEquals(3, txn.send.call_count)
|
||||
self.assertEquals(0, txn.complete.call_count)
|
||||
self.assertEquals(0, self.callback.call_count)
|
||||
txn.send = Mock(return_value=True) # successfully send the txn
|
||||
txn.send = Mock(return_value=make_awaitable(True)) # successfully send the txn
|
||||
pop_txn = True # returns the txn the first time, then no more.
|
||||
self.clock.advance_time(16)
|
||||
self.assertEquals(1, txn.send.call_count) # new mock reset call count
|
||||
|
|
|
@ -102,11 +102,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||
}
|
||||
persp_deferred = defer.Deferred()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_perspectives(**kwargs):
|
||||
async def get_perspectives(**kwargs):
|
||||
self.assertEquals(current_context().request, "11")
|
||||
with PreserveLoggingContext():
|
||||
yield persp_deferred
|
||||
await persp_deferred
|
||||
return persp_resp
|
||||
|
||||
self.http_client.post_json.side_effect = get_perspectives
|
||||
|
@ -355,7 +354,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||
}
|
||||
signedjson.sign.sign_json(response, SERVER_NAME, testkey)
|
||||
|
||||
def get_json(destination, path, **kwargs):
|
||||
async def get_json(destination, path, **kwargs):
|
||||
self.assertEqual(destination, SERVER_NAME)
|
||||
self.assertEqual(path, "/_matrix/key/v2/server/key1")
|
||||
return response
|
||||
|
@ -444,7 +443,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||
Tell the mock http client to expect a perspectives-server key query
|
||||
"""
|
||||
|
||||
def post_json(destination, path, data, **kwargs):
|
||||
async def post_json(destination, path, data, **kwargs):
|
||||
self.assertEqual(destination, self.mock_perspective_server.server_name)
|
||||
self.assertEqual(path, "/_matrix/key/v2/query")
|
||||
|
||||
|
@ -580,14 +579,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||
# remove the perspectives server's signature
|
||||
response = build_response()
|
||||
del response["signatures"][self.mock_perspective_server.server_name]
|
||||
self.http_client.post_json.return_value = {"server_keys": [response]}
|
||||
keys = get_key_from_perspectives(response)
|
||||
self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
|
||||
|
||||
# remove the origin server's signature
|
||||
response = build_response()
|
||||
del response["signatures"][SERVER_NAME]
|
||||
self.http_client.post_json.return_value = {"server_keys": [response]}
|
||||
keys = get_key_from_perspectives(response)
|
||||
self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, room
|
|||
from synapse.types import UserID
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import make_awaitable
|
||||
|
||||
|
||||
class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
||||
|
@ -78,9 +79,40 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
|||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||
handler.federation_handler.do_invite_join = Mock(
|
||||
return_value=defer.succeed(("", 1))
|
||||
return_value=make_awaitable(("", 1))
|
||||
)
|
||||
|
||||
d = handler._remote_join(
|
||||
None,
|
||||
["other.example.com"],
|
||||
"roomid",
|
||||
UserID.from_string(u1),
|
||||
{"membership": "join"},
|
||||
)
|
||||
|
||||
self.pump()
|
||||
|
||||
# The request failed with a SynapseError saying the resource limit was
|
||||
# exceeded.
|
||||
f = self.get_failure(d, SynapseError)
|
||||
self.assertEqual(f.value.code, 400, f.value)
|
||||
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
|
||||
def test_join_too_large_admin(self):
|
||||
# Check whether an admin can join if option "admins_can_join" is undefined,
|
||||
# this option defaults to false, so the join should fail.
|
||||
|
||||
u1 = self.register_user("u1", "pass", admin=True)
|
||||
|
||||
handler = self.hs.get_room_member_handler()
|
||||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||
handler.federation_handler.do_invite_join = Mock(
|
||||
return_value=make_awaitable(("", 1))
|
||||
)
|
||||
|
||||
d = handler._remote_join(
|
||||
|
@ -116,9 +148,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
|||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
|
||||
handler.federation_handler.do_invite_join = Mock(
|
||||
return_value=defer.succeed(("", 1))
|
||||
return_value=make_awaitable(("", 1))
|
||||
)
|
||||
|
||||
# Artificially raise the complexity
|
||||
|
@ -141,3 +173,81 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
|||
f = self.get_failure(d, SynapseError)
|
||||
self.assertEqual(f.value.code, 400)
|
||||
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
|
||||
|
||||
class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
|
||||
# Test the behavior of joining rooms which exceed the complexity if option
|
||||
# limit_remote_rooms.admins_can_join is True.
|
||||
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
room.register_servlets,
|
||||
login.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
config["limit_remote_rooms"] = {
|
||||
"enabled": True,
|
||||
"complexity": 0.05,
|
||||
"admins_can_join": True,
|
||||
}
|
||||
return config
|
||||
|
||||
def test_join_too_large_no_admin(self):
|
||||
# A user which is not an admin should not be able to join a remote room
|
||||
# which is too complex.
|
||||
|
||||
u1 = self.register_user("u1", "pass")
|
||||
|
||||
handler = self.hs.get_room_member_handler()
|
||||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||
handler.federation_handler.do_invite_join = Mock(
|
||||
return_value=make_awaitable(("", 1))
|
||||
)
|
||||
|
||||
d = handler._remote_join(
|
||||
None,
|
||||
["other.example.com"],
|
||||
"roomid",
|
||||
UserID.from_string(u1),
|
||||
{"membership": "join"},
|
||||
)
|
||||
|
||||
self.pump()
|
||||
|
||||
# The request failed with a SynapseError saying the resource limit was
|
||||
# exceeded.
|
||||
f = self.get_failure(d, SynapseError)
|
||||
self.assertEqual(f.value.code, 400, f.value)
|
||||
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
|
||||
def test_join_too_large_admin(self):
|
||||
# An admin should be able to join rooms where a complexity check fails.
|
||||
|
||||
u1 = self.register_user("u1", "pass", admin=True)
|
||||
|
||||
handler = self.hs.get_room_member_handler()
|
||||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||
handler.federation_handler.do_invite_join = Mock(
|
||||
return_value=make_awaitable(("", 1))
|
||||
)
|
||||
|
||||
d = handler._remote_join(
|
||||
None,
|
||||
["other.example.com"],
|
||||
"roomid",
|
||||
UserID.from_string(u1),
|
||||
{"membership": "join"},
|
||||
)
|
||||
|
||||
self.pump()
|
||||
|
||||
# The request success since the user is an admin
|
||||
self.get_success(d)
|
||||
|
|
|
@ -47,13 +47,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
|
|||
mock_send_transaction = (
|
||||
self.hs.get_federation_transport_client().send_transaction
|
||||
)
|
||||
mock_send_transaction.return_value = defer.succeed({})
|
||||
mock_send_transaction.return_value = make_awaitable({})
|
||||
|
||||
sender = self.hs.get_federation_sender()
|
||||
receipt = ReadReceipt(
|
||||
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
|
||||
)
|
||||
self.successResultOf(sender.send_read_receipt(receipt))
|
||||
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
|
||||
|
||||
self.pump()
|
||||
|
||||
|
@ -87,13 +87,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
|
|||
mock_send_transaction = (
|
||||
self.hs.get_federation_transport_client().send_transaction
|
||||
)
|
||||
mock_send_transaction.return_value = defer.succeed({})
|
||||
mock_send_transaction.return_value = make_awaitable({})
|
||||
|
||||
sender = self.hs.get_federation_sender()
|
||||
receipt = ReadReceipt(
|
||||
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
|
||||
)
|
||||
self.successResultOf(sender.send_read_receipt(receipt))
|
||||
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
|
||||
|
||||
self.pump()
|
||||
|
||||
|
@ -125,7 +125,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
|
|||
receipt = ReadReceipt(
|
||||
"room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
|
||||
)
|
||||
self.successResultOf(sender.send_read_receipt(receipt))
|
||||
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
|
||||
self.pump()
|
||||
mock_send_transaction.assert_not_called()
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue