mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-15 22:42:23 +01:00
Merge remote-tracking branch 'upstream/release-v1.56'
This commit is contained in:
commit
ef6a9c8a70
98 changed files with 6349 additions and 5436 deletions
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
|
@ -377,7 +377,7 @@ jobs:
|
||||||
# Run Complement
|
# Run Complement
|
||||||
- run: |
|
- run: |
|
||||||
set -o pipefail
|
set -o pipefail
|
||||||
go test -v -json -p 1 -tags synapse_blacklist,msc2403,msc2716,msc3030 ./tests/... 2>&1 | gotestfmt
|
go test -v -json -tags synapse_blacklist,msc2403,msc2716,msc3030 ./tests/... 2>&1 | gotestfmt
|
||||||
shell: bash
|
shell: bash
|
||||||
name: Run Complement Tests
|
name: Run Complement Tests
|
||||||
env:
|
env:
|
||||||
|
|
3640
CHANGES-pre-1.0.md
Normal file
3640
CHANGES-pre-1.0.md
Normal file
File diff suppressed because it is too large
Load diff
3724
CHANGES.md
3724
CHANGES.md
File diff suppressed because it is too large
Load diff
18
debian/changelog
vendored
18
debian/changelog
vendored
|
@ -1,3 +1,21 @@
|
||||||
|
matrix-synapse-py3 (1.56.0~rc1) stable; urgency=medium
|
||||||
|
|
||||||
|
* New synapse release 1.56.0~rc1.
|
||||||
|
|
||||||
|
-- Synapse Packaging team <packages@matrix.org> Tue, 29 Mar 2022 10:40:50 +0100
|
||||||
|
|
||||||
|
matrix-synapse-py3 (1.55.2) stable; urgency=medium
|
||||||
|
|
||||||
|
* New synapse release 1.55.2.
|
||||||
|
|
||||||
|
-- Synapse Packaging team <packages@matrix.org> Thu, 24 Mar 2022 19:07:11 +0000
|
||||||
|
|
||||||
|
matrix-synapse-py3 (1.55.1) stable; urgency=medium
|
||||||
|
|
||||||
|
* New synapse release 1.55.1.
|
||||||
|
|
||||||
|
-- Synapse Packaging team <packages@matrix.org> Thu, 24 Mar 2022 17:44:23 +0000
|
||||||
|
|
||||||
matrix-synapse-py3 (1.55.0) stable; urgency=medium
|
matrix-synapse-py3 (1.55.0) stable; urgency=medium
|
||||||
|
|
||||||
* New synapse release 1.55.0.
|
* New synapse release 1.55.0.
|
||||||
|
|
|
@ -38,6 +38,7 @@ for port in 8080 8081 8082; do
|
||||||
printf '\n\n# Customisation made by demo/start.sh\n\n'
|
printf '\n\n# Customisation made by demo/start.sh\n\n'
|
||||||
echo "public_baseurl: http://localhost:$port/"
|
echo "public_baseurl: http://localhost:$port/"
|
||||||
echo 'enable_registration: true'
|
echo 'enable_registration: true'
|
||||||
|
echo 'enable_registration_without_verification: true'
|
||||||
echo ''
|
echo ''
|
||||||
|
|
||||||
# Warning, this heredoc depends on the interaction of tabs and spaces.
|
# Warning, this heredoc depends on the interaction of tabs and spaces.
|
||||||
|
|
|
@ -172,7 +172,7 @@ any of the subsequent implementations of this callback.
|
||||||
_First introduced in Synapse v1.37.0_
|
_First introduced in Synapse v1.37.0_
|
||||||
|
|
||||||
```python
|
```python
|
||||||
async def check_username_for_spam(user_profile: Dict[str, str]) -> bool
|
async def check_username_for_spam(user_profile: synapse.module_api.UserProfile) -> bool
|
||||||
```
|
```
|
||||||
|
|
||||||
Called when computing search results in the user directory. The module must return a
|
Called when computing search results in the user directory. The module must return a
|
||||||
|
@ -182,9 +182,11 @@ search results; otherwise return `False`.
|
||||||
|
|
||||||
The profile is represented as a dictionary with the following keys:
|
The profile is represented as a dictionary with the following keys:
|
||||||
|
|
||||||
* `user_id`: The Matrix ID for this user.
|
* `user_id: str`. The Matrix ID for this user.
|
||||||
* `display_name`: The user's display name.
|
* `display_name: Optional[str]`. The user's display name, or `None` if this user
|
||||||
* `avatar_url`: The `mxc://` URL to the user's avatar.
|
has not set a display name.
|
||||||
|
* `avatar_url: Optional[str]`. The `mxc://` URL to the user's avatar, or `None`
|
||||||
|
if this user has not set an avatar.
|
||||||
|
|
||||||
The module is given a copy of the original dictionary, so modifying it from within the
|
The module is given a copy of the original dictionary, so modifying it from within the
|
||||||
module cannot modify a user's profile when included in user directory search results.
|
module cannot modify a user's profile when included in user directory search results.
|
||||||
|
|
|
@ -225,6 +225,8 @@ oidc_providers:
|
||||||
3. Create an application for synapse in Authentik and link it to the provider.
|
3. Create an application for synapse in Authentik and link it to the provider.
|
||||||
4. Note the slug of your application, Client ID and Client Secret.
|
4. Note the slug of your application, Client ID and Client Secret.
|
||||||
|
|
||||||
|
Note: RSA keys must be used for signing for Authentik, ECC keys do not work.
|
||||||
|
|
||||||
Synapse config:
|
Synapse config:
|
||||||
```yaml
|
```yaml
|
||||||
oidc_providers:
|
oidc_providers:
|
||||||
|
@ -240,7 +242,7 @@ oidc_providers:
|
||||||
- "email"
|
- "email"
|
||||||
user_mapping_provider:
|
user_mapping_provider:
|
||||||
config:
|
config:
|
||||||
localpart_template: "{{ user.preferred_username }}}"
|
localpart_template: "{{ user.preferred_username }}"
|
||||||
display_name_template: "{{ user.preferred_username|capitalize }}" # TO BE FILLED: If your users have names in Authentik and you want those in Synapse, this should be replaced with user.name|capitalize.
|
display_name_template: "{{ user.preferred_username|capitalize }}" # TO BE FILLED: If your users have names in Authentik and you want those in Synapse, this should be replaced with user.name|capitalize.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -234,12 +234,13 @@ host all all ::1/128 ident
|
||||||
### Fixing incorrect `COLLATE` or `CTYPE`
|
### Fixing incorrect `COLLATE` or `CTYPE`
|
||||||
|
|
||||||
Synapse will refuse to set up a new database if it has the wrong values of
|
Synapse will refuse to set up a new database if it has the wrong values of
|
||||||
`COLLATE` and `CTYPE` set, and will log warnings on existing databases. Using
|
`COLLATE` and `CTYPE` set. Synapse will also refuse to start an existing database with incorrect values
|
||||||
different locales can cause issues if the locale library is updated from
|
of `COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the
|
||||||
|
`database` section of the config, is set to true. Using different locales can cause issues if the locale library is updated from
|
||||||
underneath the database, or if a different version of the locale is used on any
|
underneath the database, or if a different version of the locale is used on any
|
||||||
replicas.
|
replicas.
|
||||||
|
|
||||||
The safest way to fix the issue is to dump the database and recreate it with
|
If you have a databse with an unsafe locale, the safest way to fix the issue is to dump the database and recreate it with
|
||||||
the correct locale parameter (as shown above). It is also possible to change the
|
the correct locale parameter (as shown above). It is also possible to change the
|
||||||
parameters on a live database and run a `REINDEX` on the entire database,
|
parameters on a live database and run a `REINDEX` on the entire database,
|
||||||
however extreme care must be taken to avoid database corruption.
|
however extreme care must be taken to avoid database corruption.
|
||||||
|
|
|
@ -182,7 +182,7 @@ matrix.example.com {
|
||||||
|
|
||||||
```
|
```
|
||||||
frontend https
|
frontend https
|
||||||
bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
|
bind *:443,[::]:443 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
|
||||||
http-request set-header X-Forwarded-Proto https if { ssl_fc }
|
http-request set-header X-Forwarded-Proto https if { ssl_fc }
|
||||||
http-request set-header X-Forwarded-Proto http if !{ ssl_fc }
|
http-request set-header X-Forwarded-Proto http if !{ ssl_fc }
|
||||||
http-request set-header X-Forwarded-For %[src]
|
http-request set-header X-Forwarded-For %[src]
|
||||||
|
@ -195,7 +195,7 @@ frontend https
|
||||||
use_backend matrix if matrix-host matrix-path
|
use_backend matrix if matrix-host matrix-path
|
||||||
|
|
||||||
frontend matrix-federation
|
frontend matrix-federation
|
||||||
bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
|
bind *:8448,[::]:8448 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
|
||||||
http-request set-header X-Forwarded-Proto https if { ssl_fc }
|
http-request set-header X-Forwarded-Proto https if { ssl_fc }
|
||||||
http-request set-header X-Forwarded-Proto http if !{ ssl_fc }
|
http-request set-header X-Forwarded-Proto http if !{ ssl_fc }
|
||||||
http-request set-header X-Forwarded-For %[src]
|
http-request set-header X-Forwarded-For %[src]
|
||||||
|
|
|
@ -783,6 +783,12 @@ caches:
|
||||||
# 'txn_limit' gives the maximum number of transactions to run per connection
|
# 'txn_limit' gives the maximum number of transactions to run per connection
|
||||||
# before reconnecting. Defaults to 0, which means no limit.
|
# before reconnecting. Defaults to 0, which means no limit.
|
||||||
#
|
#
|
||||||
|
# 'allow_unsafe_locale' is an option specific to Postgres. Under the default behavior, Synapse will refuse to
|
||||||
|
# start if the postgres db is set to a non-C locale. You can override this behavior (which is *not* recommended)
|
||||||
|
# by setting 'allow_unsafe_locale' to true. Note that doing so may corrupt your database. You can find more information
|
||||||
|
# here: https://matrix-org.github.io/synapse/latest/postgres.html#fixing-incorrect-collate-or-ctype and here:
|
||||||
|
# https://wiki.postgresql.org/wiki/Locale_data_changes
|
||||||
|
#
|
||||||
# 'args' gives options which are passed through to the database engine,
|
# 'args' gives options which are passed through to the database engine,
|
||||||
# except for options starting 'cp_', which are used to configure the Twisted
|
# except for options starting 'cp_', which are used to configure the Twisted
|
||||||
# connection pool. For a reference to valid arguments, see:
|
# connection pool. For a reference to valid arguments, see:
|
||||||
|
@ -1212,10 +1218,18 @@ oembed:
|
||||||
# Registration can be rate-limited using the parameters in the "Ratelimiting"
|
# Registration can be rate-limited using the parameters in the "Ratelimiting"
|
||||||
# section of this file.
|
# section of this file.
|
||||||
|
|
||||||
# Enable registration for new users.
|
# Enable registration for new users. Defaults to 'false'. It is highly recommended that if you enable registration,
|
||||||
|
# you use either captcha, email, or token-based verification to verify that new users are not bots. In order to enable registration
|
||||||
|
# without any verification, you must also set `enable_registration_without_verification`, found below.
|
||||||
#
|
#
|
||||||
#enable_registration: false
|
#enable_registration: false
|
||||||
|
|
||||||
|
# Enable registration without email or captcha verification. Note: this option is *not* recommended,
|
||||||
|
# as registration without verification is a known vector for spam and abuse. Defaults to false. Has no effect
|
||||||
|
# unless `enable_registration` is also enabled.
|
||||||
|
#
|
||||||
|
#enable_registration_without_verification: true
|
||||||
|
|
||||||
# Time that a user's session remains valid for, after they log in.
|
# Time that a user's session remains valid for, after they log in.
|
||||||
#
|
#
|
||||||
# Note that this is not currently compatible with guest logins.
|
# Note that this is not currently compatible with guest logins.
|
||||||
|
|
|
@ -99,8 +99,21 @@ experimental_features:
|
||||||
groups_enabled: false
|
groups_enabled: false
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Change in behaviour for PostgreSQL databases with unsafe locale
|
||||||
|
|
||||||
|
Synapse now refuses to start when using PostgreSQL with non-`C` values for `COLLATE` and
|
||||||
|
`CTYPE` unless the config flag `allow_unsafe_locale`, found in the database section of
|
||||||
|
the configuration file, is set to `true`. See the [PostgreSQL documentation](https://matrix-org.github.io/synapse/latest/postgres.html#fixing-incorrect-collate-or-ctype)
|
||||||
|
for more information and instructions on how to fix a database with incorrect values.
|
||||||
|
|
||||||
# Upgrading to v1.55.0
|
# Upgrading to v1.55.0
|
||||||
|
|
||||||
|
## Open registration without verification is now disabled by default
|
||||||
|
|
||||||
|
Synapse will refuse to start if registration is enabled without email, captcha, or token-based verification unless the new config
|
||||||
|
flag `enable_registration_without_verification` is set to "true".
|
||||||
|
|
||||||
|
|
||||||
## `synctl` script has been moved
|
## `synctl` script has been moved
|
||||||
|
|
||||||
The `synctl` script
|
The `synctl` script
|
||||||
|
|
|
@ -185,8 +185,8 @@ worker: refer to the [stream writers](#stream-writers) section below for further
|
||||||
information.
|
information.
|
||||||
|
|
||||||
# Sync requests
|
# Sync requests
|
||||||
^/_matrix/client/(v2_alpha|r0|v3)/sync$
|
^/_matrix/client/(r0|v3)/sync$
|
||||||
^/_matrix/client/(api/v1|v2_alpha|r0|v3)/events$
|
^/_matrix/client/(api/v1|r0|v3)/events$
|
||||||
^/_matrix/client/(api/v1|r0|v3)/initialSync$
|
^/_matrix/client/(api/v1|r0|v3)/initialSync$
|
||||||
^/_matrix/client/(api/v1|r0|v3)/rooms/[^/]+/initialSync$
|
^/_matrix/client/(api/v1|r0|v3)/rooms/[^/]+/initialSync$
|
||||||
|
|
||||||
|
@ -200,13 +200,9 @@ information.
|
||||||
^/_matrix/federation/v1/query/
|
^/_matrix/federation/v1/query/
|
||||||
^/_matrix/federation/v1/make_join/
|
^/_matrix/federation/v1/make_join/
|
||||||
^/_matrix/federation/v1/make_leave/
|
^/_matrix/federation/v1/make_leave/
|
||||||
^/_matrix/federation/v1/send_join/
|
^/_matrix/federation/(v1|v2)/send_join/
|
||||||
^/_matrix/federation/v2/send_join/
|
^/_matrix/federation/(v1|v2)/send_leave/
|
||||||
^/_matrix/federation/v1/send_leave/
|
^/_matrix/federation/(v1|v2)/invite/
|
||||||
^/_matrix/federation/v2/send_leave/
|
|
||||||
^/_matrix/federation/v1/invite/
|
|
||||||
^/_matrix/federation/v2/invite/
|
|
||||||
^/_matrix/federation/v1/query_auth/
|
|
||||||
^/_matrix/federation/v1/event_auth/
|
^/_matrix/federation/v1/event_auth/
|
||||||
^/_matrix/federation/v1/exchange_third_party_invite/
|
^/_matrix/federation/v1/exchange_third_party_invite/
|
||||||
^/_matrix/federation/v1/user/devices/
|
^/_matrix/federation/v1/user/devices/
|
||||||
|
@ -274,6 +270,8 @@ information.
|
||||||
Additionally, the following REST endpoints can be handled for GET requests:
|
Additionally, the following REST endpoints can be handled for GET requests:
|
||||||
|
|
||||||
^/_matrix/federation/v1/groups/
|
^/_matrix/federation/v1/groups/
|
||||||
|
^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
|
||||||
|
^/_matrix/client/(r0|v3|unstable)/groups/
|
||||||
|
|
||||||
Pagination requests can also be handled, but all requests for a given
|
Pagination requests can also be handled, but all requests for a given
|
||||||
room must be routed to the same instance. Additionally, care must be taken to
|
room must be routed to the same instance. Additionally, care must be taken to
|
||||||
|
@ -397,23 +395,23 @@ the stream writer for the `typing` stream:
|
||||||
The following endpoints should be routed directly to the worker configured as
|
The following endpoints should be routed directly to the worker configured as
|
||||||
the stream writer for the `to_device` stream:
|
the stream writer for the `to_device` stream:
|
||||||
|
|
||||||
^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/
|
^/_matrix/client/(r0|v3|unstable)/sendToDevice/
|
||||||
|
|
||||||
##### The `account_data` stream
|
##### The `account_data` stream
|
||||||
|
|
||||||
The following endpoints should be routed directly to the worker configured as
|
The following endpoints should be routed directly to the worker configured as
|
||||||
the stream writer for the `account_data` stream:
|
the stream writer for the `account_data` stream:
|
||||||
|
|
||||||
^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags
|
^/_matrix/client/(r0|v3|unstable)/.*/tags
|
||||||
^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data
|
^/_matrix/client/(r0|v3|unstable)/.*/account_data
|
||||||
|
|
||||||
##### The `receipts` stream
|
##### The `receipts` stream
|
||||||
|
|
||||||
The following endpoints should be routed directly to the worker configured as
|
The following endpoints should be routed directly to the worker configured as
|
||||||
the stream writer for the `receipts` stream:
|
the stream writer for the `receipts` stream:
|
||||||
|
|
||||||
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt
|
^/_matrix/client/(r0|v3|unstable)/rooms/.*/receipt
|
||||||
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers
|
^/_matrix/client/(r0|v3|unstable)/rooms/.*/read_markers
|
||||||
|
|
||||||
##### The `presence` stream
|
##### The `presence` stream
|
||||||
|
|
||||||
|
@ -528,19 +526,28 @@ Note that if a reverse proxy is used , then `/_matrix/media/` must be routed for
|
||||||
Handles searches in the user directory. It can handle REST endpoints matching
|
Handles searches in the user directory. It can handle REST endpoints matching
|
||||||
the following regular expressions:
|
the following regular expressions:
|
||||||
|
|
||||||
^/_matrix/client/(api/v1|r0|v3|unstable)/user_directory/search$
|
^/_matrix/client/(r0|v3|unstable)/user_directory/search$
|
||||||
|
|
||||||
When using this worker you must also set `update_user_directory: False` in the
|
When using this worker you must also set `update_user_directory: false` in the
|
||||||
shared configuration file to stop the main synapse running background
|
shared configuration file to stop the main synapse running background
|
||||||
jobs related to updating the user directory.
|
jobs related to updating the user directory.
|
||||||
|
|
||||||
|
Above endpoint is not *required* to be routed to this worker. By default,
|
||||||
|
`update_user_directory` is set to `true`, which means the main process
|
||||||
|
will handle updates. All workers configured with `client` can handle the above
|
||||||
|
endpoint as long as either this worker or the main process are configured to
|
||||||
|
handle it, and are online.
|
||||||
|
|
||||||
|
If `update_user_directory` is set to `false`, and this worker is not running,
|
||||||
|
the above endpoint may give outdated results.
|
||||||
|
|
||||||
### `synapse.app.frontend_proxy`
|
### `synapse.app.frontend_proxy`
|
||||||
|
|
||||||
Proxies some frequently-requested client endpoints to add caching and remove
|
Proxies some frequently-requested client endpoints to add caching and remove
|
||||||
load from the main synapse. It can handle REST endpoints matching the following
|
load from the main synapse. It can handle REST endpoints matching the following
|
||||||
regular expressions:
|
regular expressions:
|
||||||
|
|
||||||
^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload
|
^/_matrix/client/(r0|v3|unstable)/keys/upload
|
||||||
|
|
||||||
If `use_presence` is False in the homeserver config, it can also handle REST
|
If `use_presence` is False in the homeserver config, it can also handle REST
|
||||||
endpoints matching the following regular expressions:
|
endpoints matching the following regular expressions:
|
||||||
|
|
20
mypy.ini
20
mypy.ini
|
@ -38,17 +38,11 @@ exclude = (?x)
|
||||||
|synapse/_scripts/update_synapse_database.py
|
|synapse/_scripts/update_synapse_database.py
|
||||||
|
|
||||||
|synapse/storage/databases/__init__.py
|
|synapse/storage/databases/__init__.py
|
||||||
|synapse/storage/databases/main/__init__.py
|
|
||||||
|synapse/storage/databases/main/cache.py
|
|synapse/storage/databases/main/cache.py
|
||||||
|synapse/storage/databases/main/devices.py
|
|synapse/storage/databases/main/devices.py
|
||||||
|synapse/storage/databases/main/event_federation.py
|
|synapse/storage/databases/main/event_federation.py
|
||||||
|synapse/storage/databases/main/group_server.py
|
|
||||||
|synapse/storage/databases/main/metrics.py
|
|
||||||
|synapse/storage/databases/main/monthly_active_users.py
|
|
||||||
|synapse/storage/databases/main/push_rule.py
|
|synapse/storage/databases/main/push_rule.py
|
||||||
|synapse/storage/databases/main/receipts.py
|
|
||||||
|synapse/storage/databases/main/roommember.py
|
|synapse/storage/databases/main/roommember.py
|
||||||
|synapse/storage/databases/main/search.py
|
|
||||||
|synapse/storage/databases/main/state.py
|
|synapse/storage/databases/main/state.py
|
||||||
|synapse/storage/schema/
|
|synapse/storage/schema/
|
||||||
|
|
||||||
|
@ -66,14 +60,6 @@ exclude = (?x)
|
||||||
|tests/federation/test_federation_server.py
|
|tests/federation/test_federation_server.py
|
||||||
|tests/federation/transport/test_knocking.py
|
|tests/federation/transport/test_knocking.py
|
||||||
|tests/federation/transport/test_server.py
|
|tests/federation/transport/test_server.py
|
||||||
|tests/handlers/test_cas.py
|
|
||||||
|tests/handlers/test_directory.py
|
|
||||||
|tests/handlers/test_e2e_keys.py
|
|
||||||
|tests/handlers/test_federation.py
|
|
||||||
|tests/handlers/test_oidc.py
|
|
||||||
|tests/handlers/test_presence.py
|
|
||||||
|tests/handlers/test_profile.py
|
|
||||||
|tests/handlers/test_saml.py
|
|
||||||
|tests/handlers/test_typing.py
|
|tests/handlers/test_typing.py
|
||||||
|tests/http/federation/test_matrix_federation_agent.py
|
|tests/http/federation/test_matrix_federation_agent.py
|
||||||
|tests/http/federation/test_srv_resolver.py
|
|tests/http/federation/test_srv_resolver.py
|
||||||
|
@ -85,7 +71,6 @@ exclude = (?x)
|
||||||
|tests/logging/test_terse_json.py
|
|tests/logging/test_terse_json.py
|
||||||
|tests/module_api/test_api.py
|
|tests/module_api/test_api.py
|
||||||
|tests/push/test_email.py
|
|tests/push/test_email.py
|
||||||
|tests/push/test_http.py
|
|
||||||
|tests/push/test_presentable_names.py
|
|tests/push/test_presentable_names.py
|
||||||
|tests/push/test_push_rule_evaluator.py
|
|tests/push/test_push_rule_evaluator.py
|
||||||
|tests/rest/client/test_transactions.py
|
|tests/rest/client/test_transactions.py
|
||||||
|
@ -94,12 +79,7 @@ exclude = (?x)
|
||||||
|tests/server.py
|
|tests/server.py
|
||||||
|tests/server_notices/test_resource_limits_server_notices.py
|
|tests/server_notices/test_resource_limits_server_notices.py
|
||||||
|tests/state/test_v2.py
|
|tests/state/test_v2.py
|
||||||
|tests/storage/test_background_update.py
|
|
||||||
|tests/storage/test_base.py
|
|tests/storage/test_base.py
|
||||||
|tests/storage/test_client_ips.py
|
|
||||||
|tests/storage/test_database.py
|
|
||||||
|tests/storage/test_event_federation.py
|
|
||||||
|tests/storage/test_id_generators.py
|
|
||||||
|tests/storage/test_roommember.py
|
|tests/storage/test_roommember.py
|
||||||
|tests/test_metrics.py
|
|tests/test_metrics.py
|
||||||
|tests/test_phone_home.py
|
|tests/test_phone_home.py
|
||||||
|
|
|
@ -66,11 +66,15 @@ def cli():
|
||||||
|
|
||||||
./scripts-dev/release.py tag
|
./scripts-dev/release.py tag
|
||||||
|
|
||||||
# ... wait for asssets to build ...
|
# ... wait for assets to build ...
|
||||||
|
|
||||||
./scripts-dev/release.py publish
|
./scripts-dev/release.py publish
|
||||||
./scripts-dev/release.py upload
|
./scripts-dev/release.py upload
|
||||||
|
|
||||||
|
# Optional: generate some nice links for the announcement
|
||||||
|
|
||||||
|
./scripts-dev/release.py upload
|
||||||
|
|
||||||
If the env var GH_TOKEN (or GITHUB_TOKEN) is set, or passed into the
|
If the env var GH_TOKEN (or GITHUB_TOKEN) is set, or passed into the
|
||||||
`tag`/`publish` command, then a new draft release will be created/published.
|
`tag`/`publish` command, then a new draft release will be created/published.
|
||||||
"""
|
"""
|
||||||
|
@ -415,6 +419,41 @@ def upload():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
def announce():
|
||||||
|
"""Generate markdown to announce the release."""
|
||||||
|
|
||||||
|
current_version, _, _ = parse_version_from_module()
|
||||||
|
tag_name = f"v{current_version}"
|
||||||
|
|
||||||
|
click.echo(
|
||||||
|
f"""
|
||||||
|
Hi everyone. Synapse {current_version} has just been released.
|
||||||
|
|
||||||
|
[notes](https://github.com/matrix-org/synapse/releases/tag/{tag_name}) |\
|
||||||
|
[docker](https://hub.docker.com/r/matrixdotorg/synapse/tags?name={tag_name}) | \
|
||||||
|
[debs](https://packages.matrix.org/debian/) | \
|
||||||
|
[pypi](https://pypi.org/project/matrix-synapse/{current_version}/)"""
|
||||||
|
)
|
||||||
|
|
||||||
|
if "rc" in tag_name:
|
||||||
|
click.echo(
|
||||||
|
"""
|
||||||
|
Announce the RC in
|
||||||
|
- #homeowners:matrix.org (Synapse Announcements)
|
||||||
|
- #synapse-dev:matrix.org"""
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
click.echo(
|
||||||
|
"""
|
||||||
|
Announce the release in
|
||||||
|
- #homeowners:matrix.org (Synapse Announcements), bumping the version in the topic
|
||||||
|
- #synapse:matrix.org (Synapse Admins), bumping the version in the topic
|
||||||
|
- #synapse-dev:matrix.org
|
||||||
|
- #synapse-package-maintainers:matrix.org"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_version_from_module() -> Tuple[
|
def parse_version_from_module() -> Tuple[
|
||||||
version.Version, redbaron.RedBaron, redbaron.Node
|
version.Version, redbaron.RedBaron, redbaron.Node
|
||||||
]:
|
]:
|
||||||
|
|
1
setup.py
1
setup.py
|
@ -108,6 +108,7 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [
|
||||||
"types-jsonschema>=3.2.0",
|
"types-jsonschema>=3.2.0",
|
||||||
"types-opentracing>=2.4.2",
|
"types-opentracing>=2.4.2",
|
||||||
"types-Pillow>=8.3.4",
|
"types-Pillow>=8.3.4",
|
||||||
|
"types-psycopg2>=2.9.9",
|
||||||
"types-pyOpenSSL>=20.0.7",
|
"types-pyOpenSSL>=20.0.7",
|
||||||
"types-PyYAML>=5.4.10",
|
"types-PyYAML>=5.4.10",
|
||||||
"types-requests>=2.26.0",
|
"types-requests>=2.26.0",
|
||||||
|
|
|
@ -68,7 +68,7 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
__version__ = "1.55.0"
|
__version__ = "1.56.0rc1"
|
||||||
|
|
||||||
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
||||||
# We import here so that we don't have to install a bunch of deps when
|
# We import here so that we don't have to install a bunch of deps when
|
||||||
|
|
|
@ -261,7 +261,10 @@ class SynapseHomeServer(HomeServer):
|
||||||
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
|
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
|
||||||
|
|
||||||
if name == "metrics" and self.config.metrics.enable_metrics:
|
if name == "metrics" and self.config.metrics.enable_metrics:
|
||||||
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
|
metrics_resource: Resource = MetricsResource(RegistryProxy)
|
||||||
|
if compress:
|
||||||
|
metrics_resource = gz_wrap(metrics_resource)
|
||||||
|
resources[METRICS_PREFIX] = metrics_resource
|
||||||
|
|
||||||
if name == "replication":
|
if name == "replication":
|
||||||
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
|
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
|
||||||
|
@ -348,6 +351,23 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
|
||||||
if config.server.gc_seconds:
|
if config.server.gc_seconds:
|
||||||
synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
|
synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
|
||||||
|
|
||||||
|
if (
|
||||||
|
config.registration.enable_registration
|
||||||
|
and not config.registration.enable_registration_without_verification
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
not config.captcha.enable_registration_captcha
|
||||||
|
and not config.registration.registrations_require_3pid
|
||||||
|
and not config.registration.registration_requires_token
|
||||||
|
):
|
||||||
|
|
||||||
|
raise ConfigError(
|
||||||
|
"You have enabled open registration without any verification. This is a known vector for "
|
||||||
|
"spam and abuse. If you would like to allow public registration, please consider adding email, "
|
||||||
|
"captcha, or token-based verification. Otherwise this check can be removed by setting the "
|
||||||
|
"`enable_registration_without_verification` config option to `true`."
|
||||||
|
)
|
||||||
|
|
||||||
hs = SynapseHomeServer(
|
hs = SynapseHomeServer(
|
||||||
config.server.server_name,
|
config.server.server_name,
|
||||||
config=config,
|
config=config,
|
||||||
|
|
|
@ -37,6 +37,12 @@ DEFAULT_CONFIG = """\
|
||||||
# 'txn_limit' gives the maximum number of transactions to run per connection
|
# 'txn_limit' gives the maximum number of transactions to run per connection
|
||||||
# before reconnecting. Defaults to 0, which means no limit.
|
# before reconnecting. Defaults to 0, which means no limit.
|
||||||
#
|
#
|
||||||
|
# 'allow_unsafe_locale' is an option specific to Postgres. Under the default behavior, Synapse will refuse to
|
||||||
|
# start if the postgres db is set to a non-C locale. You can override this behavior (which is *not* recommended)
|
||||||
|
# by setting 'allow_unsafe_locale' to true. Note that doing so may corrupt your database. You can find more information
|
||||||
|
# here: https://matrix-org.github.io/synapse/latest/postgres.html#fixing-incorrect-collate-or-ctype and here:
|
||||||
|
# https://wiki.postgresql.org/wiki/Locale_data_changes
|
||||||
|
#
|
||||||
# 'args' gives options which are passed through to the database engine,
|
# 'args' gives options which are passed through to the database engine,
|
||||||
# except for options starting 'cp_', which are used to configure the Twisted
|
# except for options starting 'cp_', which are used to configure the Twisted
|
||||||
# connection pool. For a reference to valid arguments, see:
|
# connection pool. For a reference to valid arguments, see:
|
||||||
|
|
|
@ -33,6 +33,10 @@ class RegistrationConfig(Config):
|
||||||
str(config["disable_registration"])
|
str(config["disable_registration"])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.enable_registration_without_verification = strtobool(
|
||||||
|
str(config.get("enable_registration_without_verification", False))
|
||||||
|
)
|
||||||
|
|
||||||
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
|
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
|
||||||
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
|
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
|
||||||
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
|
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
|
||||||
|
@ -207,10 +211,18 @@ class RegistrationConfig(Config):
|
||||||
# Registration can be rate-limited using the parameters in the "Ratelimiting"
|
# Registration can be rate-limited using the parameters in the "Ratelimiting"
|
||||||
# section of this file.
|
# section of this file.
|
||||||
|
|
||||||
# Enable registration for new users.
|
# Enable registration for new users. Defaults to 'false'. It is highly recommended that if you enable registration,
|
||||||
|
# you use either captcha, email, or token-based verification to verify that new users are not bots. In order to enable registration
|
||||||
|
# without any verification, you must also set `enable_registration_without_verification`, found below.
|
||||||
#
|
#
|
||||||
#enable_registration: false
|
#enable_registration: false
|
||||||
|
|
||||||
|
# Enable registration without email or captcha verification. Note: this option is *not* recommended,
|
||||||
|
# as registration without verification is a known vector for spam and abuse. Defaults to false. Has no effect
|
||||||
|
# unless `enable_registration` is also enabled.
|
||||||
|
#
|
||||||
|
#enable_registration_without_verification: true
|
||||||
|
|
||||||
# Time that a user's session remains valid for, after they log in.
|
# Time that a user's session remains valid for, after they log in.
|
||||||
#
|
#
|
||||||
# Note that this is not currently compatible with guest logins.
|
# Note that this is not currently compatible with guest logins.
|
||||||
|
|
|
@ -676,6 +676,10 @@ class ServerConfig(Config):
|
||||||
):
|
):
|
||||||
raise ConfigError("'custom_template_directory' must be a string")
|
raise ConfigError("'custom_template_directory' must be a string")
|
||||||
|
|
||||||
|
self.use_account_validity_in_account_status: bool = (
|
||||||
|
config.get("use_account_validity_in_account_status") or False
|
||||||
|
)
|
||||||
|
|
||||||
def has_tls_listener(self) -> bool:
|
def has_tls_listener(self) -> bool:
|
||||||
return any(listener.tls for listener in self.listeners)
|
return any(listener.tls for listener in self.listeners)
|
||||||
|
|
||||||
|
|
|
@ -25,8 +25,8 @@ logger = logging.getLogger(__name__)
|
||||||
LEGACY_SPAM_CHECKER_WARNING = """
|
LEGACY_SPAM_CHECKER_WARNING = """
|
||||||
This server is using a spam checker module that is implementing the deprecated spam
|
This server is using a spam checker module that is implementing the deprecated spam
|
||||||
checker interface. Please check with the module's maintainer to see if a new version
|
checker interface. Please check with the module's maintainer to see if a new version
|
||||||
supporting Synapse's generic modules system is available.
|
supporting Synapse's generic modules system is available. For more information, please
|
||||||
For more information, please see https://matrix-org.github.io/synapse/latest/modules.html
|
see https://matrix-org.github.io/synapse/latest/modules/index.html
|
||||||
---------------------------------------------------------------------------------------"""
|
---------------------------------------------------------------------------------------"""
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,6 @@ from typing import (
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
Dict,
|
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
@ -31,7 +30,7 @@ from typing import (
|
||||||
from synapse.rest.media.v1._base import FileInfo
|
from synapse.rest.media.v1._base import FileInfo
|
||||||
from synapse.rest.media.v1.media_storage import ReadableFileWrapper
|
from synapse.rest.media.v1.media_storage import ReadableFileWrapper
|
||||||
from synapse.spam_checker_api import RegistrationBehaviour
|
from synapse.spam_checker_api import RegistrationBehaviour
|
||||||
from synapse.types import RoomAlias
|
from synapse.types import RoomAlias, UserProfile
|
||||||
from synapse.util.async_helpers import maybe_awaitable
|
from synapse.util.async_helpers import maybe_awaitable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -50,7 +49,7 @@ USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bo
|
||||||
USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
|
USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
|
||||||
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]]
|
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]]
|
||||||
USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
|
USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
|
||||||
CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]]
|
CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]]
|
||||||
LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
|
LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
|
||||||
[
|
[
|
||||||
Optional[dict],
|
Optional[dict],
|
||||||
|
@ -383,7 +382,7 @@ class SpamChecker:
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
|
async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
|
||||||
"""Checks if a user ID or display name are considered "spammy" by this server.
|
"""Checks if a user ID or display name are considered "spammy" by this server.
|
||||||
|
|
||||||
If the server considers a username spammy, then it will not be included in
|
If the server considers a username spammy, then it will not be included in
|
||||||
|
|
|
@ -38,8 +38,8 @@ from synapse.util.frozenutils import unfreeze
|
||||||
from . import EventBase
|
from . import EventBase
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from synapse.handlers.relations import BundledAggregations
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.databases.main.relations import BundledAggregations
|
|
||||||
|
|
||||||
|
|
||||||
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
|
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
|
||||||
|
|
|
@ -22,7 +22,6 @@ from typing import (
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
@ -577,10 +576,10 @@ class FederationServer(FederationBase):
|
||||||
async def _on_context_state_request_compute(
|
async def _on_context_state_request_compute(
|
||||||
self, room_id: str, event_id: Optional[str]
|
self, room_id: str, event_id: Optional[str]
|
||||||
) -> Dict[str, list]:
|
) -> Dict[str, list]:
|
||||||
|
pdus: Collection[EventBase]
|
||||||
if event_id:
|
if event_id:
|
||||||
pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu(
|
event_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
|
||||||
room_id, event_id
|
pdus = await self.store.get_events_as_list(event_ids)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
pdus = (await self.state.get_current_state(room_id)).values()
|
pdus = (await self.state.get_current_state(room_id)).values()
|
||||||
|
|
||||||
|
@ -1093,7 +1092,7 @@ class FederationServer(FederationBase):
|
||||||
# has started processing).
|
# has started processing).
|
||||||
while True:
|
while True:
|
||||||
async with lock:
|
async with lock:
|
||||||
logger.info("handling received PDU: %s", event)
|
logger.info("handling received PDU in room %s: %s", room_id, event)
|
||||||
try:
|
try:
|
||||||
with nested_logging_context(event.event_id):
|
with nested_logging_context(event.event_id):
|
||||||
await self._federation_event_handler.on_receive_pdu(
|
await self._federation_event_handler.on_receive_pdu(
|
||||||
|
|
|
@ -26,6 +26,10 @@ class AccountHandler:
|
||||||
self._main_store = hs.get_datastores().main
|
self._main_store = hs.get_datastores().main
|
||||||
self._is_mine = hs.is_mine
|
self._is_mine = hs.is_mine
|
||||||
self._federation_client = hs.get_federation_client()
|
self._federation_client = hs.get_federation_client()
|
||||||
|
self._use_account_validity_in_account_status = (
|
||||||
|
hs.config.server.use_account_validity_in_account_status
|
||||||
|
)
|
||||||
|
self._account_validity_handler = hs.get_account_validity_handler()
|
||||||
|
|
||||||
async def get_account_statuses(
|
async def get_account_statuses(
|
||||||
self,
|
self,
|
||||||
|
@ -106,6 +110,13 @@ class AccountHandler:
|
||||||
"deactivated": userinfo.is_deactivated,
|
"deactivated": userinfo.is_deactivated,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self._use_account_validity_in_account_status:
|
||||||
|
status[
|
||||||
|
"org.matrix.expired"
|
||||||
|
] = await self._account_validity_handler.is_user_expired(
|
||||||
|
user_id.to_string()
|
||||||
|
)
|
||||||
|
|
||||||
return status
|
return status
|
||||||
|
|
||||||
async def _get_remote_account_statuses(
|
async def _get_remote_account_statuses(
|
||||||
|
|
|
@ -950,54 +950,35 @@ class FederationHandler:
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
|
|
||||||
"""Returns the state at the event. i.e. not including said event."""
|
|
||||||
|
|
||||||
event = await self.store.get_event(event_id, check_room_id=room_id)
|
|
||||||
|
|
||||||
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
|
|
||||||
|
|
||||||
if state_groups:
|
|
||||||
_, state = list(state_groups.items()).pop()
|
|
||||||
results = {(e.type, e.state_key): e for e in state}
|
|
||||||
|
|
||||||
if event.is_state():
|
|
||||||
# Get previous state
|
|
||||||
if "replaces_state" in event.unsigned:
|
|
||||||
prev_id = event.unsigned["replaces_state"]
|
|
||||||
if prev_id != event.event_id:
|
|
||||||
prev_event = await self.store.get_event(prev_id)
|
|
||||||
results[(event.type, event.state_key)] = prev_event
|
|
||||||
else:
|
|
||||||
del results[(event.type, event.state_key)]
|
|
||||||
|
|
||||||
res = list(results.values())
|
|
||||||
return res
|
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
|
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
|
||||||
"""Returns the state at the event. i.e. not including said event."""
|
"""Returns the state at the event. i.e. not including said event."""
|
||||||
event = await self.store.get_event(event_id, check_room_id=room_id)
|
event = await self.store.get_event(event_id, check_room_id=room_id)
|
||||||
|
if event.internal_metadata.outlier:
|
||||||
|
raise NotFoundError("State not known at event %s" % (event_id,))
|
||||||
|
|
||||||
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
|
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
|
||||||
|
|
||||||
if state_groups:
|
# get_state_groups_ids should return exactly one result
|
||||||
_, state = list(state_groups.items()).pop()
|
assert len(state_groups) == 1
|
||||||
results = state
|
|
||||||
|
|
||||||
if event.is_state():
|
state_map = next(iter(state_groups.values()))
|
||||||
# Get previous state
|
|
||||||
|
state_key = event.get_state_key()
|
||||||
|
if state_key is not None:
|
||||||
|
# the event was not rejected (get_event raises a NotFoundError for rejected
|
||||||
|
# events) so the state at the event should include the event itself.
|
||||||
|
assert (
|
||||||
|
state_map.get((event.type, state_key)) == event.event_id
|
||||||
|
), "State at event did not include event itself"
|
||||||
|
|
||||||
|
# ... but we need the state *before* that event
|
||||||
if "replaces_state" in event.unsigned:
|
if "replaces_state" in event.unsigned:
|
||||||
prev_id = event.unsigned["replaces_state"]
|
prev_id = event.unsigned["replaces_state"]
|
||||||
if prev_id != event.event_id:
|
state_map[(event.type, state_key)] = prev_id
|
||||||
results[(event.type, event.state_key)] = prev_id
|
|
||||||
else:
|
else:
|
||||||
results.pop((event.type, event.state_key), None)
|
del state_map[(event.type, state_key)]
|
||||||
|
|
||||||
return list(results.values())
|
return list(state_map.values())
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def on_backfill_request(
|
async def on_backfill_request(
|
||||||
self, origin: str, room_id: str, pdu_list: List[str], limit: int
|
self, origin: str, room_id: str, pdu_list: List[str], limit: int
|
||||||
|
|
|
@ -495,6 +495,7 @@ class EventCreationHandler:
|
||||||
allow_no_prev_events: bool = False,
|
allow_no_prev_events: bool = False,
|
||||||
prev_event_ids: Optional[List[str]] = None,
|
prev_event_ids: Optional[List[str]] = None,
|
||||||
auth_event_ids: Optional[List[str]] = None,
|
auth_event_ids: Optional[List[str]] = None,
|
||||||
|
state_event_ids: Optional[List[str]] = None,
|
||||||
require_consent: bool = True,
|
require_consent: bool = True,
|
||||||
outlier: bool = False,
|
outlier: bool = False,
|
||||||
historical: bool = False,
|
historical: bool = False,
|
||||||
|
@ -529,6 +530,15 @@ class EventCreationHandler:
|
||||||
|
|
||||||
If non-None, prev_event_ids must also be provided.
|
If non-None, prev_event_ids must also be provided.
|
||||||
|
|
||||||
|
state_event_ids:
|
||||||
|
The full state at a given event. This is used particularly by the MSC2716
|
||||||
|
/batch_send endpoint. One use case is with insertion events which float at
|
||||||
|
the beginning of a historical batch and don't have any `prev_events` to
|
||||||
|
derive from; we add all of these state events as the explicit state so the
|
||||||
|
rest of the historical batch can inherit the same state and state_group.
|
||||||
|
This should normally be left as None, which will cause the auth_event_ids
|
||||||
|
to be calculated based on the room state at the prev_events.
|
||||||
|
|
||||||
require_consent: Whether to check if the requester has
|
require_consent: Whether to check if the requester has
|
||||||
consented to the privacy policy.
|
consented to the privacy policy.
|
||||||
|
|
||||||
|
@ -614,6 +624,7 @@ class EventCreationHandler:
|
||||||
allow_no_prev_events=allow_no_prev_events,
|
allow_no_prev_events=allow_no_prev_events,
|
||||||
prev_event_ids=prev_event_ids,
|
prev_event_ids=prev_event_ids,
|
||||||
auth_event_ids=auth_event_ids,
|
auth_event_ids=auth_event_ids,
|
||||||
|
state_event_ids=state_event_ids,
|
||||||
depth=depth,
|
depth=depth,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -773,7 +784,7 @@ class EventCreationHandler:
|
||||||
event_dict: dict,
|
event_dict: dict,
|
||||||
allow_no_prev_events: bool = False,
|
allow_no_prev_events: bool = False,
|
||||||
prev_event_ids: Optional[List[str]] = None,
|
prev_event_ids: Optional[List[str]] = None,
|
||||||
auth_event_ids: Optional[List[str]] = None,
|
state_event_ids: Optional[List[str]] = None,
|
||||||
ratelimit: bool = True,
|
ratelimit: bool = True,
|
||||||
txn_id: Optional[str] = None,
|
txn_id: Optional[str] = None,
|
||||||
ignore_shadow_ban: bool = False,
|
ignore_shadow_ban: bool = False,
|
||||||
|
@ -797,12 +808,14 @@ class EventCreationHandler:
|
||||||
The event IDs to use as the prev events.
|
The event IDs to use as the prev events.
|
||||||
Should normally be left as None to automatically request them
|
Should normally be left as None to automatically request them
|
||||||
from the database.
|
from the database.
|
||||||
auth_event_ids:
|
state_event_ids:
|
||||||
The event ids to use as the auth_events for the new event.
|
The full state at a given event. This is used particularly by the MSC2716
|
||||||
Should normally be left as None, which will cause them to be calculated
|
/batch_send endpoint. One use case is with insertion events which float at
|
||||||
based on the room state at the prev_events.
|
the beginning of a historical batch and don't have any `prev_events` to
|
||||||
|
derive from; we add all of these state events as the explicit state so the
|
||||||
If non-None, prev_event_ids must also be provided.
|
rest of the historical batch can inherit the same state and state_group.
|
||||||
|
This should normally be left as None, which will cause the auth_event_ids
|
||||||
|
to be calculated based on the room state at the prev_events.
|
||||||
ratelimit: Whether to rate limit this send.
|
ratelimit: Whether to rate limit this send.
|
||||||
txn_id: The transaction ID.
|
txn_id: The transaction ID.
|
||||||
ignore_shadow_ban: True if shadow-banned users should be allowed to
|
ignore_shadow_ban: True if shadow-banned users should be allowed to
|
||||||
|
@ -858,8 +871,9 @@ class EventCreationHandler:
|
||||||
requester,
|
requester,
|
||||||
event_dict,
|
event_dict,
|
||||||
txn_id=txn_id,
|
txn_id=txn_id,
|
||||||
|
allow_no_prev_events=allow_no_prev_events,
|
||||||
prev_event_ids=prev_event_ids,
|
prev_event_ids=prev_event_ids,
|
||||||
auth_event_ids=auth_event_ids,
|
state_event_ids=state_event_ids,
|
||||||
outlier=outlier,
|
outlier=outlier,
|
||||||
historical=historical,
|
historical=historical,
|
||||||
depth=depth,
|
depth=depth,
|
||||||
|
@ -895,6 +909,7 @@ class EventCreationHandler:
|
||||||
allow_no_prev_events: bool = False,
|
allow_no_prev_events: bool = False,
|
||||||
prev_event_ids: Optional[List[str]] = None,
|
prev_event_ids: Optional[List[str]] = None,
|
||||||
auth_event_ids: Optional[List[str]] = None,
|
auth_event_ids: Optional[List[str]] = None,
|
||||||
|
state_event_ids: Optional[List[str]] = None,
|
||||||
depth: Optional[int] = None,
|
depth: Optional[int] = None,
|
||||||
) -> Tuple[EventBase, EventContext]:
|
) -> Tuple[EventBase, EventContext]:
|
||||||
"""Create a new event for a local client
|
"""Create a new event for a local client
|
||||||
|
@ -917,6 +932,15 @@ class EventCreationHandler:
|
||||||
Should normally be left as None, which will cause them to be calculated
|
Should normally be left as None, which will cause them to be calculated
|
||||||
based on the room state at the prev_events.
|
based on the room state at the prev_events.
|
||||||
|
|
||||||
|
state_event_ids:
|
||||||
|
The full state at a given event. This is used particularly by the MSC2716
|
||||||
|
/batch_send endpoint. One use case is with insertion events which float at
|
||||||
|
the beginning of a historical batch and don't have any `prev_events` to
|
||||||
|
derive from; we add all of these state events as the explicit state so the
|
||||||
|
rest of the historical batch can inherit the same state and state_group.
|
||||||
|
This should normally be left as None, which will cause the auth_event_ids
|
||||||
|
to be calculated based on the room state at the prev_events.
|
||||||
|
|
||||||
depth: Override the depth used to order the event in the DAG.
|
depth: Override the depth used to order the event in the DAG.
|
||||||
Should normally be set to None, which will cause the depth to be calculated
|
Should normally be set to None, which will cause the depth to be calculated
|
||||||
based on the prev_events.
|
based on the prev_events.
|
||||||
|
@ -924,31 +948,26 @@ class EventCreationHandler:
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of created event, context
|
Tuple of created event, context
|
||||||
"""
|
"""
|
||||||
# Strip down the auth_event_ids to only what we need to auth the event.
|
# Strip down the state_event_ids to only what we need to auth the event.
|
||||||
# For example, we don't need extra m.room.member that don't match event.sender
|
# For example, we don't need extra m.room.member that don't match event.sender
|
||||||
full_state_ids_at_event = None
|
if state_event_ids is not None:
|
||||||
if auth_event_ids is not None:
|
# Do a quick check to make sure that prev_event_ids is present to
|
||||||
# If auth events are provided, prev events must be also.
|
# make the type-checking around `builder.build` happy.
|
||||||
# prev_event_ids could be an empty array though.
|
# prev_event_ids could be an empty array though.
|
||||||
assert prev_event_ids is not None
|
assert prev_event_ids is not None
|
||||||
|
|
||||||
# Copy the full auth state before it stripped down
|
|
||||||
full_state_ids_at_event = auth_event_ids.copy()
|
|
||||||
|
|
||||||
temp_event = await builder.build(
|
temp_event = await builder.build(
|
||||||
prev_event_ids=prev_event_ids,
|
prev_event_ids=prev_event_ids,
|
||||||
auth_event_ids=auth_event_ids,
|
auth_event_ids=state_event_ids,
|
||||||
depth=depth,
|
depth=depth,
|
||||||
)
|
)
|
||||||
auth_events = await self.store.get_events_as_list(auth_event_ids)
|
state_events = await self.store.get_events_as_list(state_event_ids)
|
||||||
# Create a StateMap[str]
|
# Create a StateMap[str]
|
||||||
auth_event_state_map = {
|
state_map = {(e.type, e.state_key): e.event_id for e in state_events}
|
||||||
(e.type, e.state_key): e.event_id for e in auth_events
|
# Actually strip down and only use the necessary auth events
|
||||||
}
|
|
||||||
# Actually strip down and use the necessary auth events
|
|
||||||
auth_event_ids = self._event_auth_handler.compute_auth_events(
|
auth_event_ids = self._event_auth_handler.compute_auth_events(
|
||||||
event=temp_event,
|
event=temp_event,
|
||||||
current_state_ids=auth_event_state_map,
|
current_state_ids=state_map,
|
||||||
for_verification=False,
|
for_verification=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -991,12 +1010,16 @@ class EventCreationHandler:
|
||||||
context = EventContext.for_outlier()
|
context = EventContext.for_outlier()
|
||||||
elif (
|
elif (
|
||||||
event.type == EventTypes.MSC2716_INSERTION
|
event.type == EventTypes.MSC2716_INSERTION
|
||||||
and full_state_ids_at_event
|
and state_event_ids
|
||||||
and builder.internal_metadata.is_historical()
|
and builder.internal_metadata.is_historical()
|
||||||
):
|
):
|
||||||
|
# Add explicit state to the insertion event so it has state to derive
|
||||||
|
# from even though it's floating with no `prev_events`. The rest of
|
||||||
|
# the batch can derive from this state and state_group.
|
||||||
|
#
|
||||||
# TODO(faster_joins): figure out how this works, and make sure that the
|
# TODO(faster_joins): figure out how this works, and make sure that the
|
||||||
# old state is complete.
|
# old state is complete.
|
||||||
old_state = await self.store.get_events_as_list(full_state_ids_at_event)
|
old_state = await self.store.get_events_as_list(state_event_ids)
|
||||||
context = await self.state.compute_event_context(event, old_state=old_state)
|
context = await self.state.compute_event_context(event, old_state=old_state)
|
||||||
else:
|
else:
|
||||||
context = await self.state.compute_event_context(event)
|
context = await self.state.compute_event_context(event)
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set
|
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -134,6 +134,7 @@ class PaginationHandler:
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self._server_name = hs.hostname
|
self._server_name = hs.hostname
|
||||||
self._room_shutdown_handler = hs.get_room_shutdown_handler()
|
self._room_shutdown_handler = hs.get_room_shutdown_handler()
|
||||||
|
self._relations_handler = hs.get_relations_handler()
|
||||||
|
|
||||||
self.pagination_lock = ReadWriteLock()
|
self.pagination_lock = ReadWriteLock()
|
||||||
# IDs of rooms in which there currently an active purge *or delete* operation.
|
# IDs of rooms in which there currently an active purge *or delete* operation.
|
||||||
|
@ -422,7 +423,7 @@ class PaginationHandler:
|
||||||
pagin_config: PaginationConfig,
|
pagin_config: PaginationConfig,
|
||||||
as_client_event: bool = True,
|
as_client_event: bool = True,
|
||||||
event_filter: Optional[Filter] = None,
|
event_filter: Optional[Filter] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> JsonDict:
|
||||||
"""Get messages in a room.
|
"""Get messages in a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -431,6 +432,7 @@ class PaginationHandler:
|
||||||
pagin_config: The pagination config rules to apply, if any.
|
pagin_config: The pagination config rules to apply, if any.
|
||||||
as_client_event: True to get events in client-server format.
|
as_client_event: True to get events in client-server format.
|
||||||
event_filter: Filter to apply to results or None
|
event_filter: Filter to apply to results or None
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Pagination API results
|
Pagination API results
|
||||||
"""
|
"""
|
||||||
|
@ -538,7 +540,9 @@ class PaginationHandler:
|
||||||
state_dict = await self.store.get_events(list(state_ids.values()))
|
state_dict = await self.store.get_events(list(state_ids.values()))
|
||||||
state = state_dict.values()
|
state = state_dict.values()
|
||||||
|
|
||||||
aggregations = await self.store.get_bundled_aggregations(events, user_id)
|
aggregations = await self._relations_handler.get_bundled_aggregations(
|
||||||
|
events, user_id
|
||||||
|
)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
|
|
|
@ -336,12 +336,18 @@ class ProfileHandler:
|
||||||
"""Check that the size and content type of the avatar at the given MXC URI are
|
"""Check that the size and content type of the avatar at the given MXC URI are
|
||||||
within the configured limits.
|
within the configured limits.
|
||||||
|
|
||||||
|
If the given `mxc` is empty, no checks are performed. (Users are always able to
|
||||||
|
unset their avatar.)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mxc: The MXC URI at which the avatar can be found.
|
mxc: The MXC URI at which the avatar can be found.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A boolean indicating whether the file can be allowed to be set as an avatar.
|
A boolean indicating whether the file can be allowed to be set as an avatar.
|
||||||
"""
|
"""
|
||||||
|
if mxc == "":
|
||||||
|
return True
|
||||||
|
|
||||||
if not self.max_avatar_size and not self.allowed_avatar_mimetypes:
|
if not self.max_avatar_size and not self.allowed_avatar_mimetypes:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
271
synapse/handlers/relations.py
Normal file
271
synapse/handlers/relations.py
Normal file
|
@ -0,0 +1,271 @@
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
|
||||||
|
|
||||||
|
import attr
|
||||||
|
from frozendict import frozendict
|
||||||
|
|
||||||
|
from synapse.api.constants import RelationTypes
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.events import EventBase
|
||||||
|
from synapse.types import JsonDict, Requester, StreamToken
|
||||||
|
from synapse.visibility import filter_events_for_client
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.databases.main import DataStore
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class _ThreadAggregation:
|
||||||
|
# The latest event in the thread.
|
||||||
|
latest_event: EventBase
|
||||||
|
# The latest edit to the latest event in the thread.
|
||||||
|
latest_edit: Optional[EventBase]
|
||||||
|
# The total number of events in the thread.
|
||||||
|
count: int
|
||||||
|
# True if the current user has sent an event to the thread.
|
||||||
|
current_user_participated: bool
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, auto_attribs=True)
|
||||||
|
class BundledAggregations:
|
||||||
|
"""
|
||||||
|
The bundled aggregations for an event.
|
||||||
|
|
||||||
|
Some values require additional processing during serialization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
annotations: Optional[JsonDict] = None
|
||||||
|
references: Optional[JsonDict] = None
|
||||||
|
replace: Optional[EventBase] = None
|
||||||
|
thread: Optional[_ThreadAggregation] = None
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return bool(self.annotations or self.references or self.replace or self.thread)
|
||||||
|
|
||||||
|
|
||||||
|
class RelationsHandler:
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
self._main_store = hs.get_datastores().main
|
||||||
|
self._storage = hs.get_storage()
|
||||||
|
self._auth = hs.get_auth()
|
||||||
|
self._clock = hs.get_clock()
|
||||||
|
self._event_handler = hs.get_event_handler()
|
||||||
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
|
|
||||||
|
async def get_relations(
|
||||||
|
self,
|
||||||
|
requester: Requester,
|
||||||
|
event_id: str,
|
||||||
|
room_id: str,
|
||||||
|
relation_type: Optional[str] = None,
|
||||||
|
event_type: Optional[str] = None,
|
||||||
|
aggregation_key: Optional[str] = None,
|
||||||
|
limit: int = 5,
|
||||||
|
direction: str = "b",
|
||||||
|
from_token: Optional[StreamToken] = None,
|
||||||
|
to_token: Optional[StreamToken] = None,
|
||||||
|
) -> JsonDict:
|
||||||
|
"""Get related events of a event, ordered by topological ordering.
|
||||||
|
|
||||||
|
TODO Accept a PaginationConfig instead of individual pagination parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
requester: The user requesting the relations.
|
||||||
|
event_id: Fetch events that relate to this event ID.
|
||||||
|
room_id: The room the event belongs to.
|
||||||
|
relation_type: Only fetch events with this relation type, if given.
|
||||||
|
event_type: Only fetch events with this event type, if given.
|
||||||
|
aggregation_key: Only fetch events with this aggregation key, if given.
|
||||||
|
limit: Only fetch the most recent `limit` events.
|
||||||
|
direction: Whether to fetch the most recent first (`"b"`) or the
|
||||||
|
oldest first (`"f"`).
|
||||||
|
from_token: Fetch rows from the given token, or from the start if None.
|
||||||
|
to_token: Fetch rows up to the given token, or up to the end if None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The pagination chunk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
# TODO Properly handle a user leaving a room.
|
||||||
|
(_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
|
||||||
|
room_id, user_id, allow_departed_users=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# This gets the original event and checks that a) the event exists and
|
||||||
|
# b) the user is allowed to view it.
|
||||||
|
event = await self._event_handler.get_event(requester.user, room_id, event_id)
|
||||||
|
if event is None:
|
||||||
|
raise SynapseError(404, "Unknown parent event.")
|
||||||
|
|
||||||
|
pagination_chunk = await self._main_store.get_relations_for_event(
|
||||||
|
event_id=event_id,
|
||||||
|
event=event,
|
||||||
|
room_id=room_id,
|
||||||
|
relation_type=relation_type,
|
||||||
|
event_type=event_type,
|
||||||
|
aggregation_key=aggregation_key,
|
||||||
|
limit=limit,
|
||||||
|
direction=direction,
|
||||||
|
from_token=from_token,
|
||||||
|
to_token=to_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
events = await self._main_store.get_events_as_list(
|
||||||
|
[c["event_id"] for c in pagination_chunk.chunk]
|
||||||
|
)
|
||||||
|
|
||||||
|
events = await filter_events_for_client(
|
||||||
|
self._storage, user_id, events, is_peeking=(member_event_id is None)
|
||||||
|
)
|
||||||
|
|
||||||
|
now = self._clock.time_msec()
|
||||||
|
# Do not bundle aggregations when retrieving the original event because
|
||||||
|
# we want the content before relations are applied to it.
|
||||||
|
original_event = self._event_serializer.serialize_event(
|
||||||
|
event, now, bundle_aggregations=None
|
||||||
|
)
|
||||||
|
# The relations returned for the requested event do include their
|
||||||
|
# bundled aggregations.
|
||||||
|
aggregations = await self.get_bundled_aggregations(
|
||||||
|
events, requester.user.to_string()
|
||||||
|
)
|
||||||
|
serialized_events = self._event_serializer.serialize_events(
|
||||||
|
events, now, bundle_aggregations=aggregations
|
||||||
|
)
|
||||||
|
|
||||||
|
return_value = await pagination_chunk.to_dict(self._main_store)
|
||||||
|
return_value["chunk"] = serialized_events
|
||||||
|
return_value["original_event"] = original_event
|
||||||
|
|
||||||
|
return return_value
|
||||||
|
|
||||||
|
async def _get_bundled_aggregation_for_event(
|
||||||
|
self, event: EventBase, user_id: str
|
||||||
|
) -> Optional[BundledAggregations]:
|
||||||
|
"""Generate bundled aggregations for an event.
|
||||||
|
|
||||||
|
Note that this does not use a cache, but depends on cached methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The event to calculate bundled aggregations for.
|
||||||
|
user_id: The user requesting the bundled aggregations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The bundled aggregations for an event, if bundled aggregations are
|
||||||
|
enabled and the event can have bundled aggregations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Do not bundle aggregations for an event which represents an edit or an
|
||||||
|
# annotation. It does not make sense for them to have related events.
|
||||||
|
relates_to = event.content.get("m.relates_to")
|
||||||
|
if isinstance(relates_to, (dict, frozendict)):
|
||||||
|
relation_type = relates_to.get("rel_type")
|
||||||
|
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
|
||||||
|
return None
|
||||||
|
|
||||||
|
event_id = event.event_id
|
||||||
|
room_id = event.room_id
|
||||||
|
|
||||||
|
# The bundled aggregations to include, a mapping of relation type to a
|
||||||
|
# type-specific value. Some types include the direct return type here
|
||||||
|
# while others need more processing during serialization.
|
||||||
|
aggregations = BundledAggregations()
|
||||||
|
|
||||||
|
annotations = await self._main_store.get_aggregation_groups_for_event(
|
||||||
|
event_id, room_id
|
||||||
|
)
|
||||||
|
if annotations.chunk:
|
||||||
|
aggregations.annotations = await annotations.to_dict(
|
||||||
|
cast("DataStore", self)
|
||||||
|
)
|
||||||
|
|
||||||
|
references = await self._main_store.get_relations_for_event(
|
||||||
|
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
|
||||||
|
)
|
||||||
|
if references.chunk:
|
||||||
|
aggregations.references = await references.to_dict(cast("DataStore", self))
|
||||||
|
|
||||||
|
# Store the bundled aggregations in the event metadata for later use.
|
||||||
|
return aggregations
|
||||||
|
|
||||||
|
async def get_bundled_aggregations(
|
||||||
|
self, events: Iterable[EventBase], user_id: str
|
||||||
|
) -> Dict[str, BundledAggregations]:
|
||||||
|
"""Generate bundled aggregations for events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
events: The iterable of events to calculate bundled aggregations for.
|
||||||
|
user_id: The user requesting the bundled aggregations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A map of event ID to the bundled aggregation for the event. Not all
|
||||||
|
events may have bundled aggregations in the results.
|
||||||
|
"""
|
||||||
|
# De-duplicate events by ID to handle the same event requested multiple times.
|
||||||
|
#
|
||||||
|
# State events do not get bundled aggregations.
|
||||||
|
events_by_id = {
|
||||||
|
event.event_id: event for event in events if not event.is_state()
|
||||||
|
}
|
||||||
|
|
||||||
|
# event ID -> bundled aggregation in non-serialized form.
|
||||||
|
results: Dict[str, BundledAggregations] = {}
|
||||||
|
|
||||||
|
# Fetch other relations per event.
|
||||||
|
for event in events_by_id.values():
|
||||||
|
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
|
||||||
|
if event_result:
|
||||||
|
results[event.event_id] = event_result
|
||||||
|
|
||||||
|
# Fetch any edits (but not for redacted events).
|
||||||
|
edits = await self._main_store.get_applicable_edits(
|
||||||
|
[
|
||||||
|
event_id
|
||||||
|
for event_id, event in events_by_id.items()
|
||||||
|
if not event.internal_metadata.is_redacted()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for event_id, edit in edits.items():
|
||||||
|
results.setdefault(event_id, BundledAggregations()).replace = edit
|
||||||
|
|
||||||
|
# Fetch thread summaries.
|
||||||
|
summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
|
||||||
|
# Only fetch participated for a limited selection based on what had
|
||||||
|
# summaries.
|
||||||
|
participated = await self._main_store.get_threads_participated(
|
||||||
|
[event_id for event_id, summary in summaries.items() if summary], user_id
|
||||||
|
)
|
||||||
|
for event_id, summary in summaries.items():
|
||||||
|
if summary:
|
||||||
|
thread_count, latest_thread_event, edit = summary
|
||||||
|
results.setdefault(
|
||||||
|
event_id, BundledAggregations()
|
||||||
|
).thread = _ThreadAggregation(
|
||||||
|
latest_event=latest_thread_event,
|
||||||
|
latest_edit=edit,
|
||||||
|
count=thread_count,
|
||||||
|
# If there's a thread summary it must also exist in the
|
||||||
|
# participated dictionary.
|
||||||
|
current_user_participated=participated[event_id],
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
|
@ -60,8 +60,8 @@ from synapse.events import EventBase
|
||||||
from synapse.events.utils import copy_power_levels_contents
|
from synapse.events.utils import copy_power_levels_contents
|
||||||
from synapse.federation.federation_client import InvalidResponseError
|
from synapse.federation.federation_client import InvalidResponseError
|
||||||
from synapse.handlers.federation import get_domains_from_state
|
from synapse.handlers.federation import get_domains_from_state
|
||||||
|
from synapse.handlers.relations import BundledAggregations
|
||||||
from synapse.rest.admin._base import assert_user_is_admin
|
from synapse.rest.admin._base import assert_user_is_admin
|
||||||
from synapse.storage.databases.main.relations import BundledAggregations
|
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.streams import EventSource
|
from synapse.streams import EventSource
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
|
@ -1128,6 +1128,7 @@ class RoomContextHandler:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = hs.get_storage()
|
self.storage = hs.get_storage()
|
||||||
self.state_store = self.storage.state
|
self.state_store = self.storage.state
|
||||||
|
self._relations_handler = hs.get_relations_handler()
|
||||||
|
|
||||||
async def get_event_context(
|
async def get_event_context(
|
||||||
self,
|
self,
|
||||||
|
@ -1200,7 +1201,7 @@ class RoomContextHandler:
|
||||||
event = filtered[0]
|
event = filtered[0]
|
||||||
|
|
||||||
# Fetch the aggregations.
|
# Fetch the aggregations.
|
||||||
aggregations = await self.store.get_bundled_aggregations(
|
aggregations = await self._relations_handler.get_bundled_aggregations(
|
||||||
itertools.chain(events_before, (event,), events_after),
|
itertools.chain(events_before, (event,), events_after),
|
||||||
user.to_string(),
|
user.to_string(),
|
||||||
)
|
)
|
||||||
|
|
|
@ -123,12 +123,11 @@ class RoomBatchHandler:
|
||||||
|
|
||||||
return create_requester(user_id, app_service=app_service)
|
return create_requester(user_id, app_service=app_service)
|
||||||
|
|
||||||
async def get_most_recent_auth_event_ids_from_event_id_list(
|
async def get_most_recent_full_state_ids_from_event_id_list(
|
||||||
self, event_ids: List[str]
|
self, event_ids: List[str]
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Find the most recent auth event ids (derived from state events) that
|
"""Find the most recent event_id and grab the full state at that event.
|
||||||
allowed that message to be sent. We will use this as a base
|
We will use this as a base to auth our historical messages against.
|
||||||
to auth our historical messages against.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_ids: List of event ID's to look at
|
event_ids: List of event ID's to look at
|
||||||
|
@ -138,38 +137,37 @@ class RoomBatchHandler:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
(
|
(
|
||||||
most_recent_prev_event_id,
|
most_recent_event_id,
|
||||||
_,
|
_,
|
||||||
) = await self.store.get_max_depth_of(event_ids)
|
) = await self.store.get_max_depth_of(event_ids)
|
||||||
# mapping from (type, state_key) -> state_event_id
|
# mapping from (type, state_key) -> state_event_id
|
||||||
prev_state_map = await self.state_store.get_state_ids_for_event(
|
prev_state_map = await self.state_store.get_state_ids_for_event(
|
||||||
most_recent_prev_event_id
|
most_recent_event_id
|
||||||
)
|
)
|
||||||
# List of state event ID's
|
# List of state event ID's
|
||||||
prev_state_ids = list(prev_state_map.values())
|
full_state_ids = list(prev_state_map.values())
|
||||||
auth_event_ids = prev_state_ids
|
|
||||||
|
|
||||||
return auth_event_ids
|
return full_state_ids
|
||||||
|
|
||||||
async def persist_state_events_at_start(
|
async def persist_state_events_at_start(
|
||||||
self,
|
self,
|
||||||
state_events_at_start: List[JsonDict],
|
state_events_at_start: List[JsonDict],
|
||||||
room_id: str,
|
room_id: str,
|
||||||
initial_auth_event_ids: List[str],
|
initial_state_event_ids: List[str],
|
||||||
app_service_requester: Requester,
|
app_service_requester: Requester,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Takes all `state_events_at_start` event dictionaries and creates/persists
|
"""Takes all `state_events_at_start` event dictionaries and creates/persists
|
||||||
them as floating state events which don't resolve into the current room state.
|
them in a floating state event chain which don't resolve into the current room
|
||||||
They are floating because they reference a fake prev_event which doesn't connect
|
state. They are floating because they reference no prev_events and are marked
|
||||||
to the normal DAG at all.
|
as outliers which disconnects them from the normal DAG.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state_events_at_start:
|
state_events_at_start:
|
||||||
room_id: Room where you want the events persisted in.
|
room_id: Room where you want the events persisted in.
|
||||||
initial_auth_event_ids: These will be the auth_events for the first
|
initial_state_event_ids:
|
||||||
state event created. Each event created afterwards will be
|
The base set of state for the historical batch which the floating
|
||||||
added to the list of auth events for the next state event
|
state chain will derive from. This should probably be the state
|
||||||
created.
|
from the `prev_event` defined by `/batch_send?prev_event_id=$abc`.
|
||||||
app_service_requester: The requester of an application service.
|
app_service_requester: The requester of an application service.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -178,7 +176,7 @@ class RoomBatchHandler:
|
||||||
assert app_service_requester.app_service
|
assert app_service_requester.app_service
|
||||||
|
|
||||||
state_event_ids_at_start = []
|
state_event_ids_at_start = []
|
||||||
auth_event_ids = initial_auth_event_ids.copy()
|
state_event_ids = initial_state_event_ids.copy()
|
||||||
|
|
||||||
# Make the state events float off on their own by specifying no
|
# Make the state events float off on their own by specifying no
|
||||||
# prev_events for the first one in the chain so we don't have a bunch of
|
# prev_events for the first one in the chain so we don't have a bunch of
|
||||||
|
@ -191,9 +189,7 @@ class RoomBatchHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s",
|
"RoomBatchSendEventRestServlet inserting state_event=%s", state_event
|
||||||
state_event,
|
|
||||||
auth_event_ids,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
event_dict = {
|
event_dict = {
|
||||||
|
@ -219,16 +215,26 @@ class RoomBatchHandler:
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
action=membership,
|
action=membership,
|
||||||
content=event_dict["content"],
|
content=event_dict["content"],
|
||||||
|
# Mark as an outlier to disconnect it from the normal DAG
|
||||||
|
# and not show up between batches of history.
|
||||||
outlier=True,
|
outlier=True,
|
||||||
historical=True,
|
historical=True,
|
||||||
# Only the first event in the chain should be floating.
|
# Only the first event in the state chain should be floating.
|
||||||
# The rest should hang off each other in a chain.
|
# The rest should hang off each other in a chain.
|
||||||
allow_no_prev_events=index == 0,
|
allow_no_prev_events=index == 0,
|
||||||
prev_event_ids=prev_event_ids_for_state_chain,
|
prev_event_ids=prev_event_ids_for_state_chain,
|
||||||
|
# Since each state event is marked as an outlier, the
|
||||||
|
# `EventContext.for_outlier()` won't have any `state_ids`
|
||||||
|
# set and therefore can't derive any state even though the
|
||||||
|
# prev_events are set. Also since the first event in the
|
||||||
|
# state chain is floating with no `prev_events`, it can't
|
||||||
|
# derive state from anywhere automatically. So we need to
|
||||||
|
# set some state explicitly.
|
||||||
|
#
|
||||||
# Make sure to use a copy of this list because we modify it
|
# Make sure to use a copy of this list because we modify it
|
||||||
# later in the loop here. Otherwise it will be the same
|
# later in the loop here. Otherwise it will be the same
|
||||||
# reference and also update in the event when we append later.
|
# reference and also update in the event when we append later.
|
||||||
auth_event_ids=auth_event_ids.copy(),
|
state_event_ids=state_event_ids.copy(),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# TODO: Add some complement tests that adds state that is not member joins
|
# TODO: Add some complement tests that adds state that is not member joins
|
||||||
|
@ -242,21 +248,31 @@ class RoomBatchHandler:
|
||||||
state_event["sender"], app_service_requester.app_service
|
state_event["sender"], app_service_requester.app_service
|
||||||
),
|
),
|
||||||
event_dict,
|
event_dict,
|
||||||
|
# Mark as an outlier to disconnect it from the normal DAG
|
||||||
|
# and not show up between batches of history.
|
||||||
outlier=True,
|
outlier=True,
|
||||||
historical=True,
|
historical=True,
|
||||||
# Only the first event in the chain should be floating.
|
# Only the first event in the state chain should be floating.
|
||||||
# The rest should hang off each other in a chain.
|
# The rest should hang off each other in a chain.
|
||||||
allow_no_prev_events=index == 0,
|
allow_no_prev_events=index == 0,
|
||||||
prev_event_ids=prev_event_ids_for_state_chain,
|
prev_event_ids=prev_event_ids_for_state_chain,
|
||||||
|
# Since each state event is marked as an outlier, the
|
||||||
|
# `EventContext.for_outlier()` won't have any `state_ids`
|
||||||
|
# set and therefore can't derive any state even though the
|
||||||
|
# prev_events are set. Also since the first event in the
|
||||||
|
# state chain is floating with no `prev_events`, it can't
|
||||||
|
# derive state from anywhere automatically. So we need to
|
||||||
|
# set some state explicitly.
|
||||||
|
#
|
||||||
# Make sure to use a copy of this list because we modify it
|
# Make sure to use a copy of this list because we modify it
|
||||||
# later in the loop here. Otherwise it will be the same
|
# later in the loop here. Otherwise it will be the same
|
||||||
# reference and also update in the event when we append later.
|
# reference and also update in the event when we append later.
|
||||||
auth_event_ids=auth_event_ids.copy(),
|
state_event_ids=state_event_ids.copy(),
|
||||||
)
|
)
|
||||||
event_id = event.event_id
|
event_id = event.event_id
|
||||||
|
|
||||||
state_event_ids_at_start.append(event_id)
|
state_event_ids_at_start.append(event_id)
|
||||||
auth_event_ids.append(event_id)
|
state_event_ids.append(event_id)
|
||||||
# Connect all the state in a floating chain
|
# Connect all the state in a floating chain
|
||||||
prev_event_ids_for_state_chain = [event_id]
|
prev_event_ids_for_state_chain = [event_id]
|
||||||
|
|
||||||
|
@ -267,7 +283,7 @@ class RoomBatchHandler:
|
||||||
events_to_create: List[JsonDict],
|
events_to_create: List[JsonDict],
|
||||||
room_id: str,
|
room_id: str,
|
||||||
inherited_depth: int,
|
inherited_depth: int,
|
||||||
auth_event_ids: List[str],
|
initial_state_event_ids: List[str],
|
||||||
app_service_requester: Requester,
|
app_service_requester: Requester,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Create and persists all events provided sequentially. Handles the
|
"""Create and persists all events provided sequentially. Handles the
|
||||||
|
@ -283,8 +299,10 @@ class RoomBatchHandler:
|
||||||
room_id: Room where you want the events persisted in.
|
room_id: Room where you want the events persisted in.
|
||||||
inherited_depth: The depth to create the events at (you will
|
inherited_depth: The depth to create the events at (you will
|
||||||
probably by calling inherit_depth_from_prev_ids(...)).
|
probably by calling inherit_depth_from_prev_ids(...)).
|
||||||
auth_event_ids: Define which events allow you to create the given
|
initial_state_event_ids:
|
||||||
event in the room.
|
This is used to set explicit state for the insertion event at
|
||||||
|
the start of the historical batch since it's floating with no
|
||||||
|
prev_events to derive state from automatically.
|
||||||
app_service_requester: The requester of an application service.
|
app_service_requester: The requester of an application service.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -292,6 +310,11 @@ class RoomBatchHandler:
|
||||||
"""
|
"""
|
||||||
assert app_service_requester.app_service
|
assert app_service_requester.app_service
|
||||||
|
|
||||||
|
# We expect the first event in a historical batch to be an insertion event
|
||||||
|
assert events_to_create[0]["type"] == EventTypes.MSC2716_INSERTION
|
||||||
|
# We expect the last event in a historical batch to be an batch event
|
||||||
|
assert events_to_create[-1]["type"] == EventTypes.MSC2716_BATCH
|
||||||
|
|
||||||
# Make the historical event chain float off on its own by specifying no
|
# Make the historical event chain float off on its own by specifying no
|
||||||
# prev_events for the first event in the chain which causes the HS to
|
# prev_events for the first event in the chain which causes the HS to
|
||||||
# ask for the state at the start of the batch later.
|
# ask for the state at the start of the batch later.
|
||||||
|
@ -323,11 +346,16 @@ class RoomBatchHandler:
|
||||||
ev["sender"], app_service_requester.app_service
|
ev["sender"], app_service_requester.app_service
|
||||||
),
|
),
|
||||||
event_dict,
|
event_dict,
|
||||||
# Only the first event in the chain should be floating.
|
# Only the first event (which is the insertion event) in the
|
||||||
# The rest should hang off each other in a chain.
|
# chain should be floating. The rest should hang off each other
|
||||||
|
# in a chain.
|
||||||
allow_no_prev_events=index == 0,
|
allow_no_prev_events=index == 0,
|
||||||
prev_event_ids=event_dict.get("prev_events"),
|
prev_event_ids=event_dict.get("prev_events"),
|
||||||
auth_event_ids=auth_event_ids,
|
# Since the first event (which is the insertion event) in the
|
||||||
|
# chain is floating with no `prev_events`, it can't derive state
|
||||||
|
# from anywhere automatically. So we need to set some state
|
||||||
|
# explicitly.
|
||||||
|
state_event_ids=initial_state_event_ids if index == 0 else None,
|
||||||
historical=True,
|
historical=True,
|
||||||
depth=inherited_depth,
|
depth=inherited_depth,
|
||||||
)
|
)
|
||||||
|
@ -345,10 +373,9 @@ class RoomBatchHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s",
|
"RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s",
|
||||||
event,
|
event,
|
||||||
prev_event_ids,
|
prev_event_ids,
|
||||||
auth_event_ids,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
events_to_persist.append((event, context))
|
events_to_persist.append((event, context))
|
||||||
|
@ -378,12 +405,12 @@ class RoomBatchHandler:
|
||||||
room_id: str,
|
room_id: str,
|
||||||
batch_id_to_connect_to: str,
|
batch_id_to_connect_to: str,
|
||||||
inherited_depth: int,
|
inherited_depth: int,
|
||||||
auth_event_ids: List[str],
|
initial_state_event_ids: List[str],
|
||||||
app_service_requester: Requester,
|
app_service_requester: Requester,
|
||||||
) -> Tuple[List[str], str]:
|
) -> Tuple[List[str], str]:
|
||||||
"""
|
"""
|
||||||
Handles creating and persisting all of the historical events as well
|
Handles creating and persisting all of the historical events as well as
|
||||||
as insertion and batch meta events to make the batch navigable in the DAG.
|
insertion and batch meta events to make the batch navigable in the DAG.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
events_to_create: List of historical events to create in JSON
|
events_to_create: List of historical events to create in JSON
|
||||||
|
@ -393,8 +420,13 @@ class RoomBatchHandler:
|
||||||
want this batch to connect to.
|
want this batch to connect to.
|
||||||
inherited_depth: The depth to create the events at (you will
|
inherited_depth: The depth to create the events at (you will
|
||||||
probably by calling inherit_depth_from_prev_ids(...)).
|
probably by calling inherit_depth_from_prev_ids(...)).
|
||||||
auth_event_ids: Define which events allow you to create the given
|
initial_state_event_ids:
|
||||||
event in the room.
|
This is used to set explicit state for the insertion event at
|
||||||
|
the start of the historical batch since it's floating with no
|
||||||
|
prev_events to derive state from automatically. This should
|
||||||
|
probably be the state from the `prev_event` defined by
|
||||||
|
`/batch_send?prev_event_id=$abc` plus the outcome of
|
||||||
|
`persist_state_events_at_start`
|
||||||
app_service_requester: The requester of an application service.
|
app_service_requester: The requester of an application service.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -440,7 +472,7 @@ class RoomBatchHandler:
|
||||||
events_to_create=events_to_create,
|
events_to_create=events_to_create,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
inherited_depth=inherited_depth,
|
inherited_depth=inherited_depth,
|
||||||
auth_event_ids=auth_event_ids,
|
initial_state_event_ids=initial_state_event_ids,
|
||||||
app_service_requester=app_service_requester,
|
app_service_requester=app_service_requester,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -271,7 +271,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
membership: str,
|
membership: str,
|
||||||
allow_no_prev_events: bool = False,
|
allow_no_prev_events: bool = False,
|
||||||
prev_event_ids: Optional[List[str]] = None,
|
prev_event_ids: Optional[List[str]] = None,
|
||||||
auth_event_ids: Optional[List[str]] = None,
|
state_event_ids: Optional[List[str]] = None,
|
||||||
txn_id: Optional[str] = None,
|
txn_id: Optional[str] = None,
|
||||||
ratelimit: bool = True,
|
ratelimit: bool = True,
|
||||||
content: Optional[dict] = None,
|
content: Optional[dict] = None,
|
||||||
|
@ -294,10 +294,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
events should have a prev_event and we should only use this in special
|
events should have a prev_event and we should only use this in special
|
||||||
cases like MSC2716.
|
cases like MSC2716.
|
||||||
prev_event_ids: The event IDs to use as the prev events
|
prev_event_ids: The event IDs to use as the prev events
|
||||||
auth_event_ids:
|
state_event_ids:
|
||||||
The event ids to use as the auth_events for the new event.
|
The full state at a given event. This is used particularly by the MSC2716
|
||||||
Should normally be left as None, which will cause them to be calculated
|
/batch_send endpoint. One use case is the historical `state_events_at_start`;
|
||||||
based on the room state at the prev_events.
|
since each is marked as an `outlier`, the `EventContext.for_outlier()` won't
|
||||||
|
have any `state_ids` set and therefore can't derive any state even though the
|
||||||
|
prev_events are set so we need to set them ourself via this argument.
|
||||||
|
This should normally be left as None, which will cause the auth_event_ids
|
||||||
|
to be calculated based on the room state at the prev_events.
|
||||||
|
|
||||||
txn_id:
|
txn_id:
|
||||||
ratelimit:
|
ratelimit:
|
||||||
|
@ -352,7 +356,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
txn_id=txn_id,
|
txn_id=txn_id,
|
||||||
allow_no_prev_events=allow_no_prev_events,
|
allow_no_prev_events=allow_no_prev_events,
|
||||||
prev_event_ids=prev_event_ids,
|
prev_event_ids=prev_event_ids,
|
||||||
auth_event_ids=auth_event_ids,
|
state_event_ids=state_event_ids,
|
||||||
require_consent=require_consent,
|
require_consent=require_consent,
|
||||||
outlier=outlier,
|
outlier=outlier,
|
||||||
historical=historical,
|
historical=historical,
|
||||||
|
@ -455,7 +459,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
historical: bool = False,
|
historical: bool = False,
|
||||||
allow_no_prev_events: bool = False,
|
allow_no_prev_events: bool = False,
|
||||||
prev_event_ids: Optional[List[str]] = None,
|
prev_event_ids: Optional[List[str]] = None,
|
||||||
auth_event_ids: Optional[List[str]] = None,
|
state_event_ids: Optional[List[str]] = None,
|
||||||
) -> Tuple[str, int]:
|
) -> Tuple[str, int]:
|
||||||
"""Update a user's membership in a room.
|
"""Update a user's membership in a room.
|
||||||
|
|
||||||
|
@ -483,10 +487,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
events should have a prev_event and we should only use this in special
|
events should have a prev_event and we should only use this in special
|
||||||
cases like MSC2716.
|
cases like MSC2716.
|
||||||
prev_event_ids: The event IDs to use as the prev events
|
prev_event_ids: The event IDs to use as the prev events
|
||||||
auth_event_ids:
|
state_event_ids:
|
||||||
The event ids to use as the auth_events for the new event.
|
The full state at a given event. This is used particularly by the MSC2716
|
||||||
Should normally be left as None, which will cause them to be calculated
|
/batch_send endpoint. One use case is the historical `state_events_at_start`;
|
||||||
based on the room state at the prev_events.
|
since each is marked as an `outlier`, the `EventContext.for_outlier()` won't
|
||||||
|
have any `state_ids` set and therefore can't derive any state even though the
|
||||||
|
prev_events are set so we need to set them ourself via this argument.
|
||||||
|
This should normally be left as None, which will cause the auth_event_ids
|
||||||
|
to be calculated based on the room state at the prev_events.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of the new event ID and stream ID.
|
A tuple of the new event ID and stream ID.
|
||||||
|
@ -525,7 +533,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
historical=historical,
|
historical=historical,
|
||||||
allow_no_prev_events=allow_no_prev_events,
|
allow_no_prev_events=allow_no_prev_events,
|
||||||
prev_event_ids=prev_event_ids,
|
prev_event_ids=prev_event_ids,
|
||||||
auth_event_ids=auth_event_ids,
|
state_event_ids=state_event_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
@ -547,7 +555,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
historical: bool = False,
|
historical: bool = False,
|
||||||
allow_no_prev_events: bool = False,
|
allow_no_prev_events: bool = False,
|
||||||
prev_event_ids: Optional[List[str]] = None,
|
prev_event_ids: Optional[List[str]] = None,
|
||||||
auth_event_ids: Optional[List[str]] = None,
|
state_event_ids: Optional[List[str]] = None,
|
||||||
) -> Tuple[str, int]:
|
) -> Tuple[str, int]:
|
||||||
"""Helper for update_membership.
|
"""Helper for update_membership.
|
||||||
|
|
||||||
|
@ -577,10 +585,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
events should have a prev_event and we should only use this in special
|
events should have a prev_event and we should only use this in special
|
||||||
cases like MSC2716.
|
cases like MSC2716.
|
||||||
prev_event_ids: The event IDs to use as the prev events
|
prev_event_ids: The event IDs to use as the prev events
|
||||||
auth_event_ids:
|
state_event_ids:
|
||||||
The event ids to use as the auth_events for the new event.
|
The full state at a given event. This is used particularly by the MSC2716
|
||||||
Should normally be left as None, which will cause them to be calculated
|
/batch_send endpoint. One use case is the historical `state_events_at_start`;
|
||||||
based on the room state at the prev_events.
|
since each is marked as an `outlier`, the `EventContext.for_outlier()` won't
|
||||||
|
have any `state_ids` set and therefore can't derive any state even though the
|
||||||
|
prev_events are set so we need to set them ourself via this argument.
|
||||||
|
This should normally be left as None, which will cause the auth_event_ids
|
||||||
|
to be calculated based on the room state at the prev_events.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of the new event ID and stream ID.
|
A tuple of the new event ID and stream ID.
|
||||||
|
@ -707,7 +719,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
ratelimit=ratelimit,
|
ratelimit=ratelimit,
|
||||||
allow_no_prev_events=allow_no_prev_events,
|
allow_no_prev_events=allow_no_prev_events,
|
||||||
prev_event_ids=prev_event_ids,
|
prev_event_ids=prev_event_ids,
|
||||||
auth_event_ids=auth_event_ids,
|
state_event_ids=state_event_ids,
|
||||||
content=content,
|
content=content,
|
||||||
require_consent=require_consent,
|
require_consent=require_consent,
|
||||||
outlier=outlier,
|
outlier=outlier,
|
||||||
|
@ -931,7 +943,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
txn_id=txn_id,
|
txn_id=txn_id,
|
||||||
ratelimit=ratelimit,
|
ratelimit=ratelimit,
|
||||||
prev_event_ids=latest_event_ids,
|
prev_event_ids=latest_event_ids,
|
||||||
auth_event_ids=auth_event_ids,
|
state_event_ids=state_event_ids,
|
||||||
content=content,
|
content=content,
|
||||||
require_consent=require_consent,
|
require_consent=require_consent,
|
||||||
outlier=outlier,
|
outlier=outlier,
|
||||||
|
|
|
@ -54,6 +54,7 @@ class SearchHandler:
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
|
self._relations_handler = hs.get_relations_handler()
|
||||||
self.storage = hs.get_storage()
|
self.storage = hs.get_storage()
|
||||||
self.state_store = self.storage.state
|
self.state_store = self.storage.state
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
@ -354,7 +355,7 @@ class SearchHandler:
|
||||||
|
|
||||||
aggregations = None
|
aggregations = None
|
||||||
if self._msc3666_enabled:
|
if self._msc3666_enabled:
|
||||||
aggregations = await self.store.get_bundled_aggregations(
|
aggregations = await self._relations_handler.get_bundled_aggregations(
|
||||||
# Generate an iterable of EventBase for all the events that will be
|
# Generate an iterable of EventBase for all the events that will be
|
||||||
# returned, including contextual events.
|
# returned, including contextual events.
|
||||||
itertools.chain(
|
itertools.chain(
|
||||||
|
|
|
@ -28,16 +28,16 @@ from typing import (
|
||||||
import attr
|
import attr
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
|
from synapse.api.constants import EventTypes, Membership, ReceiptTypes
|
||||||
from synapse.api.filtering import FilterCollection
|
from synapse.api.filtering import FilterCollection
|
||||||
from synapse.api.presence import UserPresenceState
|
from synapse.api.presence import UserPresenceState
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
|
from synapse.handlers.relations import BundledAggregations
|
||||||
from synapse.logging.context import current_context
|
from synapse.logging.context import current_context
|
||||||
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
|
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
|
||||||
from synapse.push.clientformat import format_push_rules_for_user
|
from synapse.push.clientformat import format_push_rules_for_user
|
||||||
from synapse.storage.databases.main.event_push_actions import NotifCounts
|
from synapse.storage.databases.main.event_push_actions import NotifCounts
|
||||||
from synapse.storage.databases.main.relations import BundledAggregations
|
|
||||||
from synapse.storage.roommember import MemberSummary
|
from synapse.storage.roommember import MemberSummary
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
|
@ -269,6 +269,7 @@ class SyncHandler:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
|
self._relations_handler = hs.get_relations_handler()
|
||||||
self.event_sources = hs.get_event_sources()
|
self.event_sources = hs.get_event_sources()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
@ -638,9 +639,11 @@ class SyncHandler:
|
||||||
# as clients will have all the necessary information.
|
# as clients will have all the necessary information.
|
||||||
bundled_aggregations = None
|
bundled_aggregations = None
|
||||||
if limited or newly_joined_room:
|
if limited or newly_joined_room:
|
||||||
bundled_aggregations = await self.store.get_bundled_aggregations(
|
bundled_aggregations = (
|
||||||
|
await self._relations_handler.get_bundled_aggregations(
|
||||||
recents, sync_config.user.to_string()
|
recents, sync_config.user.to_string()
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return TimelineBatch(
|
return TimelineBatch(
|
||||||
events=recents,
|
events=recents,
|
||||||
|
@ -1600,7 +1603,7 @@ class SyncHandler:
|
||||||
return set(), set(), set(), set()
|
return set(), set(), set(), set()
|
||||||
|
|
||||||
# 3. Work out which rooms need reporting in the sync response.
|
# 3. Work out which rooms need reporting in the sync response.
|
||||||
ignored_users = await self._get_ignored_users(user_id)
|
ignored_users = await self.store.ignored_users(user_id)
|
||||||
if since_token:
|
if since_token:
|
||||||
room_changes = await self._get_rooms_changed(
|
room_changes = await self._get_rooms_changed(
|
||||||
sync_result_builder, ignored_users
|
sync_result_builder, ignored_users
|
||||||
|
@ -1626,7 +1629,6 @@ class SyncHandler:
|
||||||
logger.debug("Generating room entry for %s", room_entry.room_id)
|
logger.debug("Generating room entry for %s", room_entry.room_id)
|
||||||
await self._generate_room_entry(
|
await self._generate_room_entry(
|
||||||
sync_result_builder,
|
sync_result_builder,
|
||||||
ignored_users,
|
|
||||||
room_entry,
|
room_entry,
|
||||||
ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
|
ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
|
||||||
tags=tags_by_room.get(room_entry.room_id),
|
tags=tags_by_room.get(room_entry.room_id),
|
||||||
|
@ -1656,29 +1658,6 @@ class SyncHandler:
|
||||||
newly_left_users,
|
newly_left_users,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
|
|
||||||
"""Retrieve the users ignored by the given user from their global account_data.
|
|
||||||
|
|
||||||
Returns an empty set if
|
|
||||||
- there is no global account_data entry for ignored_users
|
|
||||||
- there is such an entry, but it's not a JSON object.
|
|
||||||
"""
|
|
||||||
# TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
|
|
||||||
ignored_account_data = (
|
|
||||||
await self.store.get_global_account_data_by_type_for_user(
|
|
||||||
user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# If there is ignored users account data and it matches the proper type,
|
|
||||||
# then use it.
|
|
||||||
ignored_users: FrozenSet[str] = frozenset()
|
|
||||||
if ignored_account_data:
|
|
||||||
ignored_users_data = ignored_account_data.get("ignored_users", {})
|
|
||||||
if isinstance(ignored_users_data, dict):
|
|
||||||
ignored_users = frozenset(ignored_users_data.keys())
|
|
||||||
return ignored_users
|
|
||||||
|
|
||||||
async def _have_rooms_changed(
|
async def _have_rooms_changed(
|
||||||
self, sync_result_builder: "SyncResultBuilder"
|
self, sync_result_builder: "SyncResultBuilder"
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
@ -2021,7 +2000,6 @@ class SyncHandler:
|
||||||
async def _generate_room_entry(
|
async def _generate_room_entry(
|
||||||
self,
|
self,
|
||||||
sync_result_builder: "SyncResultBuilder",
|
sync_result_builder: "SyncResultBuilder",
|
||||||
ignored_users: FrozenSet[str],
|
|
||||||
room_builder: "RoomSyncResultBuilder",
|
room_builder: "RoomSyncResultBuilder",
|
||||||
ephemeral: List[JsonDict],
|
ephemeral: List[JsonDict],
|
||||||
tags: Optional[Dict[str, Dict[str, Any]]],
|
tags: Optional[Dict[str, Dict[str, Any]]],
|
||||||
|
@ -2050,7 +2028,6 @@ class SyncHandler:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sync_result_builder
|
sync_result_builder
|
||||||
ignored_users: Set of users ignored by user.
|
|
||||||
room_builder
|
room_builder
|
||||||
ephemeral: List of new ephemeral events for room
|
ephemeral: List of new ephemeral events for room
|
||||||
tags: List of *all* tags for room, or None if there has been
|
tags: List of *all* tags for room, or None if there has been
|
||||||
|
|
|
@ -19,8 +19,8 @@ import synapse.metrics
|
||||||
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
|
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
|
||||||
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
|
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
from synapse.storage.databases.main.user_directory import SearchResult
|
||||||
from synapse.storage.roommember import ProfileInfo
|
from synapse.storage.roommember import ProfileInfo
|
||||||
from synapse.types import JsonDict
|
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -78,7 +78,7 @@ class UserDirectoryHandler(StateDeltasHandler):
|
||||||
|
|
||||||
async def search_users(
|
async def search_users(
|
||||||
self, user_id: str, search_term: str, limit: int
|
self, user_id: str, search_term: str, limit: int
|
||||||
) -> JsonDict:
|
) -> SearchResult:
|
||||||
"""Searches for users in directory
|
"""Searches for users in directory
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
|
@ -111,6 +111,7 @@ from synapse.types import (
|
||||||
StateMap,
|
StateMap,
|
||||||
UserID,
|
UserID,
|
||||||
UserInfo,
|
UserInfo,
|
||||||
|
UserProfile,
|
||||||
create_requester,
|
create_requester,
|
||||||
)
|
)
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
@ -150,6 +151,7 @@ __all__ = [
|
||||||
"EventBase",
|
"EventBase",
|
||||||
"StateMap",
|
"StateMap",
|
||||||
"ProfileInfo",
|
"ProfileInfo",
|
||||||
|
"UserProfile",
|
||||||
]
|
]
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -609,15 +611,18 @@ class ModuleApi:
|
||||||
localpart: str,
|
localpart: str,
|
||||||
displayname: Optional[str] = None,
|
displayname: Optional[str] = None,
|
||||||
emails: Optional[List[str]] = None,
|
emails: Optional[List[str]] = None,
|
||||||
|
admin: bool = False,
|
||||||
) -> "defer.Deferred[str]":
|
) -> "defer.Deferred[str]":
|
||||||
"""Registers a new user with given localpart and optional displayname, emails.
|
"""Registers a new user with given localpart and optional displayname, emails.
|
||||||
|
|
||||||
Added in Synapse v1.2.0.
|
Added in Synapse v1.2.0.
|
||||||
|
Changed in Synapse v1.56.0: add 'admin' argument to register the user as admin.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
localpart: The localpart of the new user.
|
localpart: The localpart of the new user.
|
||||||
displayname: The displayname of the new user.
|
displayname: The displayname of the new user.
|
||||||
emails: Emails to bind to the new user.
|
emails: Emails to bind to the new user.
|
||||||
|
admin: True if the user should be registered as a server admin.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError if there is an error performing the registration. Check the
|
SynapseError if there is an error performing the registration. Check the
|
||||||
|
@ -631,6 +636,7 @@ class ModuleApi:
|
||||||
localpart=localpart,
|
localpart=localpart,
|
||||||
default_display_name=displayname,
|
default_display_name=displayname,
|
||||||
bind_emails=emails or [],
|
bind_emails=emails or [],
|
||||||
|
admin=admin,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -665,7 +671,8 @@ class ModuleApi:
|
||||||
def record_user_external_id(
|
def record_user_external_id(
|
||||||
self, auth_provider_id: str, remote_user_id: str, registered_user_id: str
|
self, auth_provider_id: str, remote_user_id: str, registered_user_id: str
|
||||||
) -> defer.Deferred:
|
) -> defer.Deferred:
|
||||||
"""Record a mapping from an external user id to a mxid
|
"""Record a mapping between an external user id from a single sign-on provider
|
||||||
|
and a mxid.
|
||||||
|
|
||||||
Added in Synapse v1.9.0.
|
Added in Synapse v1.9.0.
|
||||||
|
|
||||||
|
@ -1280,6 +1287,30 @@ class ModuleApi:
|
||||||
"""
|
"""
|
||||||
await self._registration_handler.check_username(username)
|
await self._registration_handler.check_username(username)
|
||||||
|
|
||||||
|
async def store_remote_3pid_association(
|
||||||
|
self, user_id: str, medium: str, address: str, id_server: str
|
||||||
|
) -> None:
|
||||||
|
"""Stores an existing association between a user ID and a third-party identifier.
|
||||||
|
|
||||||
|
The association must already exist on the remote identity server.
|
||||||
|
|
||||||
|
Added in Synapse v1.56.0.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID that's been associated with the 3PID.
|
||||||
|
medium: The medium of the 3PID (current supported values are "msisdn" and
|
||||||
|
"email").
|
||||||
|
address: The address of the 3PID.
|
||||||
|
id_server: The identity server the 3PID association has been registered on.
|
||||||
|
This should only be the domain (or IP address, optionally with the port
|
||||||
|
number) for the identity server. This will be used to reach out to the
|
||||||
|
identity server using HTTPS (unless specified otherwise by Synapse's
|
||||||
|
configuration) when attempting to unbind the third-party identifier.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
await self._store.add_user_bound_threepid(user_id, medium, address, id_server)
|
||||||
|
|
||||||
|
|
||||||
class PublicRoomListManager:
|
class PublicRoomListManager:
|
||||||
"""Contains methods for adding to, removing from and querying whether a room
|
"""Contains methods for adding to, removing from and querying whether a room
|
||||||
|
|
|
@ -24,6 +24,7 @@ from synapse.event_auth import get_user_power_level
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.state import POWER_KEY
|
from synapse.state import POWER_KEY
|
||||||
|
from synapse.storage.databases.main.roommember import EventIdMembership
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches import CacheMetric, register_cache
|
from synapse.util.caches import CacheMetric, register_cache
|
||||||
from synapse.util.caches.descriptors import lru_cache
|
from synapse.util.caches.descriptors import lru_cache
|
||||||
|
@ -213,7 +214,7 @@ class BulkPushRuleEvaluator:
|
||||||
if not event.is_state():
|
if not event.is_state():
|
||||||
ignorers = await self.store.ignored_by(event.sender)
|
ignorers = await self.store.ignored_by(event.sender)
|
||||||
else:
|
else:
|
||||||
ignorers = set()
|
ignorers = frozenset()
|
||||||
|
|
||||||
for uid, rules in rules_by_user.items():
|
for uid, rules in rules_by_user.items():
|
||||||
if event.sender == uid:
|
if event.sender == uid:
|
||||||
|
@ -292,7 +293,7 @@ def _condition_checker(
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
MemberMap = Dict[str, Tuple[str, str]]
|
MemberMap = Dict[str, Optional[EventIdMembership]]
|
||||||
Rule = Dict[str, dict]
|
Rule = Dict[str, dict]
|
||||||
RulesByUser = Dict[str, List[Rule]]
|
RulesByUser = Dict[str, List[Rule]]
|
||||||
StateGroup = Union[object, int]
|
StateGroup = Union[object, int]
|
||||||
|
@ -306,7 +307,7 @@ class RulesForRoomData:
|
||||||
*only* include data, and not references to e.g. the data stores.
|
*only* include data, and not references to e.g. the data stores.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# event_id -> (user_id, state)
|
# event_id -> EventIdMembership
|
||||||
member_map: MemberMap = attr.Factory(dict)
|
member_map: MemberMap = attr.Factory(dict)
|
||||||
# user_id -> rules
|
# user_id -> rules
|
||||||
rules_by_user: RulesByUser = attr.Factory(dict)
|
rules_by_user: RulesByUser = attr.Factory(dict)
|
||||||
|
@ -447,11 +448,10 @@ class RulesForRoom:
|
||||||
|
|
||||||
res = self.data.member_map.get(event_id, None)
|
res = self.data.member_map.get(event_id, None)
|
||||||
if res:
|
if res:
|
||||||
user_id, state = res
|
if res.membership == Membership.JOIN:
|
||||||
if state == Membership.JOIN:
|
rules = self.data.rules_by_user.get(res.user_id, None)
|
||||||
rules = self.data.rules_by_user.get(user_id, None)
|
|
||||||
if rules:
|
if rules:
|
||||||
ret_rules_by_user[user_id] = rules
|
ret_rules_by_user[res.user_id] = rules
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# If a user has left a room we remove their push rule. If they
|
# If a user has left a room we remove their push rule. If they
|
||||||
|
@ -502,24 +502,26 @@ class RulesForRoom:
|
||||||
"""
|
"""
|
||||||
sequence = self.data.sequence
|
sequence = self.data.sequence
|
||||||
|
|
||||||
rows = await self.store.get_membership_from_event_ids(member_event_ids.values())
|
members = await self.store.get_membership_from_event_ids(
|
||||||
|
member_event_ids.values()
|
||||||
|
)
|
||||||
|
|
||||||
members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
|
# If the event is a join event then it will be in current state events
|
||||||
|
|
||||||
# If the event is a join event then it will be in current state evnts
|
|
||||||
# map but not in the DB, so we have to explicitly insert it.
|
# map but not in the DB, so we have to explicitly insert it.
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
for event_id in member_event_ids.values():
|
for event_id in member_event_ids.values():
|
||||||
if event_id == event.event_id:
|
if event_id == event.event_id:
|
||||||
members[event_id] = (event.state_key, event.membership)
|
members[event_id] = EventIdMembership(
|
||||||
|
user_id=event.state_key, membership=event.membership
|
||||||
|
)
|
||||||
|
|
||||||
if logger.isEnabledFor(logging.DEBUG):
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
logger.debug("Found members %r: %r", self.room_id, members.values())
|
logger.debug("Found members %r: %r", self.room_id, members.values())
|
||||||
|
|
||||||
joined_user_ids = {
|
joined_user_ids = {
|
||||||
user_id
|
entry.user_id
|
||||||
for user_id, membership in members.values()
|
for entry in members.values()
|
||||||
if membership == Membership.JOIN
|
if entry and entry.membership == Membership.JOIN
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug("Joined: %r", joined_user_ids)
|
logger.debug("Joined: %r", joined_user_ids)
|
||||||
|
|
|
@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar
|
||||||
|
|
||||||
import bleach
|
import bleach
|
||||||
import jinja2
|
import jinja2
|
||||||
|
from markupsafe import Markup
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership, RoomTypes
|
from synapse.api.constants import EventTypes, Membership, RoomTypes
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
|
@ -867,7 +868,7 @@ class Mailer:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def safe_markup(raw_html: str) -> jinja2.Markup:
|
def safe_markup(raw_html: str) -> Markup:
|
||||||
"""
|
"""
|
||||||
Sanitise a raw HTML string to a set of allowed tags and attributes, and linkify any bare URLs.
|
Sanitise a raw HTML string to a set of allowed tags and attributes, and linkify any bare URLs.
|
||||||
|
|
||||||
|
@ -877,7 +878,7 @@ def safe_markup(raw_html: str) -> jinja2.Markup:
|
||||||
Returns:
|
Returns:
|
||||||
A Markup object ready to safely use in a Jinja template.
|
A Markup object ready to safely use in a Jinja template.
|
||||||
"""
|
"""
|
||||||
return jinja2.Markup(
|
return Markup(
|
||||||
bleach.linkify(
|
bleach.linkify(
|
||||||
bleach.clean(
|
bleach.clean(
|
||||||
raw_html,
|
raw_html,
|
||||||
|
@ -891,7 +892,7 @@ def safe_markup(raw_html: str) -> jinja2.Markup:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def safe_text(raw_text: str) -> jinja2.Markup:
|
def safe_text(raw_text: str) -> Markup:
|
||||||
"""
|
"""
|
||||||
Sanitise text (escape any HTML tags), and then linkify any bare URLs.
|
Sanitise text (escape any HTML tags), and then linkify any bare URLs.
|
||||||
|
|
||||||
|
@ -901,7 +902,7 @@ def safe_text(raw_text: str) -> jinja2.Markup:
|
||||||
Returns:
|
Returns:
|
||||||
A Markup object ready to safely use in a Jinja template.
|
A Markup object ready to safely use in a Jinja template.
|
||||||
"""
|
"""
|
||||||
return jinja2.Markup(
|
return Markup(
|
||||||
bleach.linkify(bleach.clean(raw_text, tags=[], attributes=[], strip=False))
|
bleach.linkify(bleach.clean(raw_text, tags=[], attributes=[], strip=False))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -74,7 +74,10 @@ REQUIREMENTS = [
|
||||||
# Note: 21.1.0 broke `/sync`, see #9936
|
# Note: 21.1.0 broke `/sync`, see #9936
|
||||||
"attrs>=19.2.0,!=21.1.0",
|
"attrs>=19.2.0,!=21.1.0",
|
||||||
"netaddr>=0.7.18",
|
"netaddr>=0.7.18",
|
||||||
"Jinja2>=2.9",
|
# Jinja 2.x is incompatible with MarkupSafe>=2.1. To ensure that admins do not
|
||||||
|
# end up with a broken installation, with recent MarkupSafe but old Jinja, we
|
||||||
|
# add a lower bound to the Jinja2 dependency.
|
||||||
|
"Jinja2>=3.0",
|
||||||
"bleach>=1.4.3",
|
"bleach>=1.4.3",
|
||||||
# We use `ParamSpec`, which was added in `typing-extensions` 3.10.0.0.
|
# We use `ParamSpec`, which was added in `typing-extensions` 3.10.0.0.
|
||||||
"typing-extensions>=3.10.0",
|
"typing-extensions>=3.10.0",
|
||||||
|
|
|
@ -32,6 +32,7 @@ from synapse.rest.client import (
|
||||||
knock,
|
knock,
|
||||||
login as v1_login,
|
login as v1_login,
|
||||||
logout,
|
logout,
|
||||||
|
mutual_rooms,
|
||||||
notifications,
|
notifications,
|
||||||
openid,
|
openid,
|
||||||
password_policy,
|
password_policy,
|
||||||
|
@ -49,7 +50,6 @@ from synapse.rest.client import (
|
||||||
room_keys,
|
room_keys,
|
||||||
room_upgrade_rest_servlet,
|
room_upgrade_rest_servlet,
|
||||||
sendtodevice,
|
sendtodevice,
|
||||||
shared_rooms,
|
|
||||||
sync,
|
sync,
|
||||||
tags,
|
tags,
|
||||||
thirdparty,
|
thirdparty,
|
||||||
|
@ -132,4 +132,4 @@ class ClientRestResource(JsonResource):
|
||||||
admin.register_servlets_for_client_rest_resource(hs, client_resource)
|
admin.register_servlets_for_client_rest_resource(hs, client_resource)
|
||||||
|
|
||||||
# unstable
|
# unstable
|
||||||
shared_rooms.register_servlets(hs, client_resource)
|
mutual_rooms.register_servlets(hs, client_resource)
|
||||||
|
|
|
@ -28,13 +28,13 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UserSharedRoomsServlet(RestServlet):
|
class UserMutualRoomsServlet(RestServlet):
|
||||||
"""
|
"""
|
||||||
GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1
|
GET /uk.half-shot.msc2666/user/mutual_rooms/{user_id} HTTP/1.1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PATTERNS = client_patterns(
|
PATTERNS = client_patterns(
|
||||||
"/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)",
|
"/uk.half-shot.msc2666/user/mutual_rooms/(?P<user_id>[^/]*)",
|
||||||
releases=(), # This is an unstable feature
|
releases=(), # This is an unstable feature
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -42,17 +42,19 @@ class UserSharedRoomsServlet(RestServlet):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.user_directory_active = hs.config.server.update_user_directory
|
self.user_directory_search_enabled = (
|
||||||
|
hs.config.userdirectory.user_directory_search_enabled
|
||||||
|
)
|
||||||
|
|
||||||
async def on_GET(
|
async def on_GET(
|
||||||
self, request: SynapseRequest, user_id: str
|
self, request: SynapseRequest, user_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
|
|
||||||
if not self.user_directory_active:
|
if not self.user_directory_search_enabled:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
code=400,
|
code=400,
|
||||||
msg="The user directory is disabled on this server. Cannot determine shared rooms.",
|
msg="User directory searching is disabled. Cannot determine shared rooms.",
|
||||||
errcode=Codes.FORBIDDEN,
|
errcode=Codes.UNKNOWN,
|
||||||
)
|
)
|
||||||
|
|
||||||
UserID.from_string(user_id)
|
UserID.from_string(user_id)
|
||||||
|
@ -64,7 +66,8 @@ class UserSharedRoomsServlet(RestServlet):
|
||||||
msg="You cannot request a list of shared rooms with yourself",
|
msg="You cannot request a list of shared rooms with yourself",
|
||||||
errcode=Codes.FORBIDDEN,
|
errcode=Codes.FORBIDDEN,
|
||||||
)
|
)
|
||||||
rooms = await self.store.get_shared_rooms_for_users(
|
|
||||||
|
rooms = await self.store.get_mutual_rooms_for_users(
|
||||||
requester.user.to_string(), user_id
|
requester.user.to_string(), user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -72,4 +75,4 @@ class UserSharedRoomsServlet(RestServlet):
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
UserSharedRoomsServlet(hs).register(http_server)
|
UserMutualRoomsServlet(hs).register(http_server)
|
|
@ -51,9 +51,7 @@ class RelationPaginationServlet(RestServlet):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.clock = hs.get_clock()
|
self._relations_handler = hs.get_relations_handler()
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
|
||||||
self.event_handler = hs.get_event_handler()
|
|
||||||
|
|
||||||
async def on_GET(
|
async def on_GET(
|
||||||
self,
|
self,
|
||||||
|
@ -65,16 +63,6 @@ class RelationPaginationServlet(RestServlet):
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
await self.auth.check_user_in_room_or_world_readable(
|
|
||||||
room_id, requester.user.to_string(), allow_departed_users=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# This gets the original event and checks that a) the event exists and
|
|
||||||
# b) the user is allowed to view it.
|
|
||||||
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
|
|
||||||
if event is None:
|
|
||||||
raise SynapseError(404, "Unknown parent event.")
|
|
||||||
|
|
||||||
limit = parse_integer(request, "limit", default=5)
|
limit = parse_integer(request, "limit", default=5)
|
||||||
direction = parse_string(
|
direction = parse_string(
|
||||||
request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"]
|
request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"]
|
||||||
|
@ -90,9 +78,9 @@ class RelationPaginationServlet(RestServlet):
|
||||||
if to_token_str:
|
if to_token_str:
|
||||||
to_token = await StreamToken.from_string(self.store, to_token_str)
|
to_token = await StreamToken.from_string(self.store, to_token_str)
|
||||||
|
|
||||||
pagination_chunk = await self.store.get_relations_for_event(
|
result = await self._relations_handler.get_relations(
|
||||||
|
requester=requester,
|
||||||
event_id=parent_id,
|
event_id=parent_id,
|
||||||
event=event,
|
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
relation_type=relation_type,
|
relation_type=relation_type,
|
||||||
event_type=event_type,
|
event_type=event_type,
|
||||||
|
@ -102,30 +90,7 @@ class RelationPaginationServlet(RestServlet):
|
||||||
to_token=to_token,
|
to_token=to_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
events = await self.store.get_events_as_list(
|
return 200, result
|
||||||
[c["event_id"] for c in pagination_chunk.chunk]
|
|
||||||
)
|
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
|
||||||
# Do not bundle aggregations when retrieving the original event because
|
|
||||||
# we want the content before relations are applied to it.
|
|
||||||
original_event = self._event_serializer.serialize_event(
|
|
||||||
event, now, bundle_aggregations=None
|
|
||||||
)
|
|
||||||
# The relations returned for the requested event do include their
|
|
||||||
# bundled aggregations.
|
|
||||||
aggregations = await self.store.get_bundled_aggregations(
|
|
||||||
events, requester.user.to_string()
|
|
||||||
)
|
|
||||||
serialized_events = self._event_serializer.serialize_events(
|
|
||||||
events, now, bundle_aggregations=aggregations
|
|
||||||
)
|
|
||||||
|
|
||||||
return_value = await pagination_chunk.to_dict(self.store)
|
|
||||||
return_value["chunk"] = serialized_events
|
|
||||||
return_value["original_event"] = original_event
|
|
||||||
|
|
||||||
return 200, return_value
|
|
||||||
|
|
||||||
|
|
||||||
class RelationAggregationPaginationServlet(RestServlet):
|
class RelationAggregationPaginationServlet(RestServlet):
|
||||||
|
@ -245,9 +210,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.clock = hs.get_clock()
|
self._relations_handler = hs.get_relations_handler()
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
|
||||||
self.event_handler = hs.get_event_handler()
|
|
||||||
|
|
||||||
async def on_GET(
|
async def on_GET(
|
||||||
self,
|
self,
|
||||||
|
@ -260,18 +223,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
await self.auth.check_user_in_room_or_world_readable(
|
|
||||||
room_id,
|
|
||||||
requester.user.to_string(),
|
|
||||||
allow_departed_users=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# This checks that a) the event exists and b) the user is allowed to
|
|
||||||
# view it.
|
|
||||||
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
|
|
||||||
if event is None:
|
|
||||||
raise SynapseError(404, "Unknown parent event.")
|
|
||||||
|
|
||||||
if relation_type != RelationTypes.ANNOTATION:
|
if relation_type != RelationTypes.ANNOTATION:
|
||||||
raise SynapseError(400, "Relation type must be 'annotation'")
|
raise SynapseError(400, "Relation type must be 'annotation'")
|
||||||
|
|
||||||
|
@ -286,9 +237,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
|
||||||
if to_token_str:
|
if to_token_str:
|
||||||
to_token = await StreamToken.from_string(self.store, to_token_str)
|
to_token = await StreamToken.from_string(self.store, to_token_str)
|
||||||
|
|
||||||
result = await self.store.get_relations_for_event(
|
result = await self._relations_handler.get_relations(
|
||||||
|
requester=requester,
|
||||||
event_id=parent_id,
|
event_id=parent_id,
|
||||||
event=event,
|
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
relation_type=relation_type,
|
relation_type=relation_type,
|
||||||
event_type=event_type,
|
event_type=event_type,
|
||||||
|
@ -298,17 +249,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
|
||||||
to_token=to_token,
|
to_token=to_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
events = await self.store.get_events_as_list(
|
return 200, result
|
||||||
[c["event_id"] for c in result.chunk]
|
|
||||||
)
|
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
|
||||||
serialized_events = self._event_serializer.serialize_events(events, now)
|
|
||||||
|
|
||||||
return_value = await result.to_dict(self.store)
|
|
||||||
return_value["chunk"] = serialized_events
|
|
||||||
|
|
||||||
return 200, return_value
|
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
|
|
|
@ -649,6 +649,7 @@ class RoomEventServlet(RestServlet):
|
||||||
self._store = hs.get_datastores().main
|
self._store = hs.get_datastores().main
|
||||||
self.event_handler = hs.get_event_handler()
|
self.event_handler = hs.get_event_handler()
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
|
self._relations_handler = hs.get_relations_handler()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_GET(
|
async def on_GET(
|
||||||
|
@ -667,7 +668,7 @@ class RoomEventServlet(RestServlet):
|
||||||
|
|
||||||
if event:
|
if event:
|
||||||
# Ensure there are bundled aggregations available.
|
# Ensure there are bundled aggregations available.
|
||||||
aggregations = await self._store.get_bundled_aggregations(
|
aggregations = await self._relations_handler.get_bundled_aggregations(
|
||||||
[event], requester.user.to_string()
|
[event], requester.user.to_string()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -124,14 +124,14 @@ class RoomBatchSendEventRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
# For the event we are inserting next to (`prev_event_ids_from_query`),
|
# For the event we are inserting next to (`prev_event_ids_from_query`),
|
||||||
# find the most recent auth events (derived from state events) that
|
# find the most recent state events that allowed that message to be
|
||||||
# allowed that message to be sent. We will use that as a base
|
# sent. We will use that as a base to auth our historical messages
|
||||||
# to auth our historical messages against.
|
# against.
|
||||||
auth_event_ids = await self.room_batch_handler.get_most_recent_auth_event_ids_from_event_id_list(
|
state_event_ids = await self.room_batch_handler.get_most_recent_full_state_ids_from_event_id_list(
|
||||||
prev_event_ids_from_query
|
prev_event_ids_from_query
|
||||||
)
|
)
|
||||||
|
|
||||||
if not auth_event_ids:
|
if not state_event_ids:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
"No auth events found for given prev_event query parameter. The prev_event=%s probably does not exist."
|
"No auth events found for given prev_event query parameter. The prev_event=%s probably does not exist."
|
||||||
|
@ -148,13 +148,13 @@ class RoomBatchSendEventRestServlet(RestServlet):
|
||||||
await self.room_batch_handler.persist_state_events_at_start(
|
await self.room_batch_handler.persist_state_events_at_start(
|
||||||
state_events_at_start=body["state_events_at_start"],
|
state_events_at_start=body["state_events_at_start"],
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
initial_auth_event_ids=auth_event_ids,
|
initial_state_event_ids=state_event_ids,
|
||||||
app_service_requester=requester,
|
app_service_requester=requester,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Update our ongoing auth event ID list with all of the new state we
|
# Update our ongoing auth event ID list with all of the new state we
|
||||||
# just created
|
# just created
|
||||||
auth_event_ids.extend(state_event_ids_at_start)
|
state_event_ids.extend(state_event_ids_at_start)
|
||||||
|
|
||||||
inherited_depth = await self.room_batch_handler.inherit_depth_from_prev_ids(
|
inherited_depth = await self.room_batch_handler.inherit_depth_from_prev_ids(
|
||||||
prev_event_ids_from_query
|
prev_event_ids_from_query
|
||||||
|
@ -196,7 +196,12 @@ class RoomBatchSendEventRestServlet(RestServlet):
|
||||||
),
|
),
|
||||||
base_insertion_event_dict,
|
base_insertion_event_dict,
|
||||||
prev_event_ids=base_insertion_event_dict.get("prev_events"),
|
prev_event_ids=base_insertion_event_dict.get("prev_events"),
|
||||||
auth_event_ids=auth_event_ids,
|
# Also set the explicit state here because we want to resolve
|
||||||
|
# any `state_events_at_start` here too. It's not strictly
|
||||||
|
# necessary to accomplish anything but if someone asks for the
|
||||||
|
# state at this point, we probably want to show them the
|
||||||
|
# historical state that was part of this batch.
|
||||||
|
state_event_ids=state_event_ids,
|
||||||
historical=True,
|
historical=True,
|
||||||
depth=inherited_depth,
|
depth=inherited_depth,
|
||||||
)
|
)
|
||||||
|
@ -212,7 +217,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
batch_id_to_connect_to=batch_id_to_connect_to,
|
batch_id_to_connect_to=batch_id_to_connect_to,
|
||||||
inherited_depth=inherited_depth,
|
inherited_depth=inherited_depth,
|
||||||
auth_event_ids=auth_event_ids,
|
initial_state_event_ids=state_event_ids,
|
||||||
app_service_requester=requester,
|
app_service_requester=requester,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonMapping
|
||||||
|
|
||||||
from ._base import client_patterns
|
from ._base import client_patterns
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ class UserDirectorySearchRestServlet(RestServlet):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.user_directory_handler = hs.get_user_directory_handler()
|
self.user_directory_handler = hs.get_user_directory_handler()
|
||||||
|
|
||||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonMapping]:
|
||||||
"""Searches for users in directory
|
"""Searches for users in directory
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
|
@ -16,7 +16,6 @@ import itertools
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
|
from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
|
||||||
from urllib import parse as urlparse
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from lxml import etree
|
from lxml import etree
|
||||||
|
@ -144,9 +143,7 @@ def decode_body(
|
||||||
return etree.fromstring(body, parser)
|
return etree.fromstring(body, parser)
|
||||||
|
|
||||||
|
|
||||||
def parse_html_to_open_graph(
|
def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
|
||||||
tree: "etree.Element", media_uri: str
|
|
||||||
) -> Dict[str, Optional[str]]:
|
|
||||||
"""
|
"""
|
||||||
Parse the HTML document into an Open Graph response.
|
Parse the HTML document into an Open Graph response.
|
||||||
|
|
||||||
|
@ -155,7 +152,6 @@ def parse_html_to_open_graph(
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tree: The parsed HTML document.
|
tree: The parsed HTML document.
|
||||||
media_url: The URI used to download the body.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The Open Graph response as a dictionary.
|
The Open Graph response as a dictionary.
|
||||||
|
@ -209,7 +205,7 @@ def parse_html_to_open_graph(
|
||||||
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
|
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
|
||||||
)
|
)
|
||||||
if meta_image:
|
if meta_image:
|
||||||
og["og:image"] = rebase_url(meta_image[0], media_uri)
|
og["og:image"] = meta_image[0]
|
||||||
else:
|
else:
|
||||||
# TODO: consider inlined CSS styles as well as width & height attribs
|
# TODO: consider inlined CSS styles as well as width & height attribs
|
||||||
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
|
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
|
||||||
|
@ -320,37 +316,6 @@ def _iterate_over_text(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def rebase_url(url: str, base: str) -> str:
|
|
||||||
"""
|
|
||||||
Resolves a potentially relative `url` against an absolute `base` URL.
|
|
||||||
|
|
||||||
For example:
|
|
||||||
|
|
||||||
>>> rebase_url("subpage", "https://example.com/foo/")
|
|
||||||
'https://example.com/foo/subpage'
|
|
||||||
>>> rebase_url("sibling", "https://example.com/foo")
|
|
||||||
'https://example.com/sibling'
|
|
||||||
>>> rebase_url("/bar", "https://example.com/foo/")
|
|
||||||
'https://example.com/bar'
|
|
||||||
>>> rebase_url("https://alice.com/a/", "https://example.com/foo/")
|
|
||||||
'https://alice.com/a'
|
|
||||||
"""
|
|
||||||
base_parts = urlparse.urlparse(base)
|
|
||||||
# Convert the parsed URL to a list for (potential) modification.
|
|
||||||
url_parts = list(urlparse.urlparse(url))
|
|
||||||
# Add a scheme, if one does not exist.
|
|
||||||
if not url_parts[0]:
|
|
||||||
url_parts[0] = base_parts.scheme or "http"
|
|
||||||
# Fix up the hostname, if this is not a data URL.
|
|
||||||
if url_parts[0] != "data" and not url_parts[1]:
|
|
||||||
url_parts[1] = base_parts.netloc
|
|
||||||
# If the path does not start with a /, nest it under the base path's last
|
|
||||||
# directory.
|
|
||||||
if not url_parts[2].startswith("/"):
|
|
||||||
url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2]
|
|
||||||
return urlparse.urlunparse(url_parts)
|
|
||||||
|
|
||||||
|
|
||||||
def summarize_paragraphs(
|
def summarize_paragraphs(
|
||||||
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
|
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
|
|
|
@ -22,7 +22,7 @@ import shutil
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple
|
from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple
|
||||||
from urllib import parse as urlparse
|
from urllib.parse import urljoin, urlparse, urlsplit
|
||||||
from urllib.request import urlopen
|
from urllib.request import urlopen
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -44,11 +44,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.rest.media.v1._base import get_filename_from_headers
|
from synapse.rest.media.v1._base import get_filename_from_headers
|
||||||
from synapse.rest.media.v1.media_storage import MediaStorage
|
from synapse.rest.media.v1.media_storage import MediaStorage
|
||||||
from synapse.rest.media.v1.oembed import OEmbedProvider
|
from synapse.rest.media.v1.oembed import OEmbedProvider
|
||||||
from synapse.rest.media.v1.preview_html import (
|
from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph
|
||||||
decode_body,
|
|
||||||
parse_html_to_open_graph,
|
|
||||||
rebase_url,
|
|
||||||
)
|
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, UserID
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.async_helpers import ObservableDeferred
|
from synapse.util.async_helpers import ObservableDeferred
|
||||||
|
@ -187,7 +183,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||||
ts = self.clock.time_msec()
|
ts = self.clock.time_msec()
|
||||||
|
|
||||||
# XXX: we could move this into _do_preview if we wanted.
|
# XXX: we could move this into _do_preview if we wanted.
|
||||||
url_tuple = urlparse.urlsplit(url)
|
url_tuple = urlsplit(url)
|
||||||
for entry in self.url_preview_url_blacklist:
|
for entry in self.url_preview_url_blacklist:
|
||||||
match = True
|
match = True
|
||||||
for attrib in entry:
|
for attrib in entry:
|
||||||
|
@ -322,7 +318,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||||
|
|
||||||
# Parse Open Graph information from the HTML in case the oEmbed
|
# Parse Open Graph information from the HTML in case the oEmbed
|
||||||
# response failed or is incomplete.
|
# response failed or is incomplete.
|
||||||
og_from_html = parse_html_to_open_graph(tree, media_info.uri)
|
og_from_html = parse_html_to_open_graph(tree)
|
||||||
|
|
||||||
# Compile the Open Graph response by using the scraped
|
# Compile the Open Graph response by using the scraped
|
||||||
# information from the HTML and overlaying any information
|
# information from the HTML and overlaying any information
|
||||||
|
@ -588,12 +584,17 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||||
if "og:image" not in og or not og["og:image"]:
|
if "og:image" not in og or not og["og:image"]:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# The image URL from the HTML might be relative to the previewed page,
|
||||||
|
# convert it to an URL which can be requested directly.
|
||||||
|
image_url = og["og:image"]
|
||||||
|
url_parts = urlparse(image_url)
|
||||||
|
if url_parts.scheme != "data":
|
||||||
|
image_url = urljoin(media_info.uri, image_url)
|
||||||
|
|
||||||
# FIXME: it might be cleaner to use the same flow as the main /preview_url
|
# FIXME: it might be cleaner to use the same flow as the main /preview_url
|
||||||
# request itself and benefit from the same caching etc. But for now we
|
# request itself and benefit from the same caching etc. But for now we
|
||||||
# just rely on the caching on the master request to speed things up.
|
# just rely on the caching on the master request to speed things up.
|
||||||
image_info = await self._handle_url(
|
image_info = await self._handle_url(image_url, user, allow_data_urls=True)
|
||||||
rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if _is_media(image_info.media_type):
|
if _is_media(image_info.media_type):
|
||||||
# TODO: make sure we don't choke on white-on-transparent images
|
# TODO: make sure we don't choke on white-on-transparent images
|
||||||
|
|
|
@ -94,6 +94,7 @@ from synapse.handlers.profile import ProfileHandler
|
||||||
from synapse.handlers.read_marker import ReadMarkerHandler
|
from synapse.handlers.read_marker import ReadMarkerHandler
|
||||||
from synapse.handlers.receipts import ReceiptsHandler
|
from synapse.handlers.receipts import ReceiptsHandler
|
||||||
from synapse.handlers.register import RegistrationHandler
|
from synapse.handlers.register import RegistrationHandler
|
||||||
|
from synapse.handlers.relations import RelationsHandler
|
||||||
from synapse.handlers.room import (
|
from synapse.handlers.room import (
|
||||||
RoomContextHandler,
|
RoomContextHandler,
|
||||||
RoomCreationHandler,
|
RoomCreationHandler,
|
||||||
|
@ -719,6 +720,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
def get_pagination_handler(self) -> PaginationHandler:
|
def get_pagination_handler(self) -> PaginationHandler:
|
||||||
return PaginationHandler(self)
|
return PaginationHandler(self)
|
||||||
|
|
||||||
|
@cache_in_self
|
||||||
|
def get_relations_handler(self) -> RelationsHandler:
|
||||||
|
return RelationsHandler(self)
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_room_context_handler(self) -> RoomContextHandler:
|
def get_room_context_handler(self) -> RoomContextHandler:
|
||||||
return RoomContextHandler(self)
|
return RoomContextHandler(self)
|
||||||
|
|
|
@ -41,6 +41,7 @@ from prometheus_client import Histogram
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from twisted.enterprise import adbapi
|
from twisted.enterprise import adbapi
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.config.database import DatabaseConnectionConfig
|
from synapse.config.database import DatabaseConnectionConfig
|
||||||
|
@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage.background_updates import BackgroundUpdater
|
from synapse.storage.background_updates import BackgroundUpdater
|
||||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||||
from synapse.storage.types import Connection, Cursor
|
from synapse.storage.types import Connection, Cursor
|
||||||
|
from synapse.util.async_helpers import delay_cancellation
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -286,13 +288,17 @@ class LoggingTransaction:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
from psycopg2.extras import execute_batch # type: ignore
|
from psycopg2.extras import execute_batch
|
||||||
|
|
||||||
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
|
self._do_execute(
|
||||||
|
lambda the_sql: execute_batch(self.txn, the_sql, args), sql
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.executemany(sql, args)
|
self.executemany(sql, args)
|
||||||
|
|
||||||
def execute_values(self, sql: str, *args: Any, fetch: bool = True) -> List[Tuple]:
|
def execute_values(
|
||||||
|
self, sql: str, values: Iterable[Iterable[Any]], fetch: bool = True
|
||||||
|
) -> List[Tuple]:
|
||||||
"""Corresponds to psycopg2.extras.execute_values. Only available when
|
"""Corresponds to psycopg2.extras.execute_values. Only available when
|
||||||
using postgres.
|
using postgres.
|
||||||
|
|
||||||
|
@ -300,10 +306,11 @@ class LoggingTransaction:
|
||||||
rows (e.g. INSERTs).
|
rows (e.g. INSERTs).
|
||||||
"""
|
"""
|
||||||
assert isinstance(self.database_engine, PostgresEngine)
|
assert isinstance(self.database_engine, PostgresEngine)
|
||||||
from psycopg2.extras import execute_values # type: ignore
|
from psycopg2.extras import execute_values
|
||||||
|
|
||||||
return self._do_execute(
|
return self._do_execute(
|
||||||
lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args
|
lambda the_sql: execute_values(self.txn, the_sql, values, fetch=fetch),
|
||||||
|
sql,
|
||||||
)
|
)
|
||||||
|
|
||||||
def execute(self, sql: str, *args: Any) -> None:
|
def execute(self, sql: str, *args: Any) -> None:
|
||||||
|
@ -732,6 +739,8 @@ class DatabasePool:
|
||||||
Returns:
|
Returns:
|
||||||
The result of func
|
The result of func
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
async def _runInteraction() -> R:
|
||||||
after_callbacks: List[_CallbackListEntry] = []
|
after_callbacks: List[_CallbackListEntry] = []
|
||||||
exception_callbacks: List[_CallbackListEntry] = []
|
exception_callbacks: List[_CallbackListEntry] = []
|
||||||
|
|
||||||
|
@ -754,12 +763,21 @@ class DatabasePool:
|
||||||
|
|
||||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||||
after_callback(*after_args, **after_kwargs)
|
after_callback(*after_args, **after_kwargs)
|
||||||
|
|
||||||
|
return cast(R, result)
|
||||||
except Exception:
|
except Exception:
|
||||||
for after_callback, after_args, after_kwargs in exception_callbacks:
|
for after_callback, after_args, after_kwargs in exception_callbacks:
|
||||||
after_callback(*after_args, **after_kwargs)
|
after_callback(*after_args, **after_kwargs)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return cast(R, result)
|
# To handle cancellation, we ensure that `after_callback`s and
|
||||||
|
# `exception_callback`s are always run, since the transaction will complete
|
||||||
|
# on another thread regardless of cancellation.
|
||||||
|
#
|
||||||
|
# We also wait until everything above is done before releasing the
|
||||||
|
# `CancelledError`, so that logging contexts won't get used after they have been
|
||||||
|
# finished.
|
||||||
|
return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
|
||||||
|
|
||||||
async def runWithConnection(
|
async def runWithConnection(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -14,7 +14,17 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
FrozenSet,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
from synapse.api.constants import AccountDataTypes
|
from synapse.api.constants import AccountDataTypes
|
||||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||||
|
@ -365,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached(max_entries=5000, iterable=True)
|
@cached(max_entries=5000, iterable=True)
|
||||||
async def ignored_by(self, user_id: str) -> Set[str]:
|
async def ignored_by(self, user_id: str) -> FrozenSet[str]:
|
||||||
"""
|
"""
|
||||||
Get users which ignore the given user.
|
Get users which ignore the given user.
|
||||||
|
|
||||||
|
@ -375,7 +385,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
||||||
Return:
|
Return:
|
||||||
The user IDs which ignore the given user.
|
The user IDs which ignore the given user.
|
||||||
"""
|
"""
|
||||||
return set(
|
return frozenset(
|
||||||
await self.db_pool.simple_select_onecol(
|
await self.db_pool.simple_select_onecol(
|
||||||
table="ignored_users",
|
table="ignored_users",
|
||||||
keyvalues={"ignored_user_id": user_id},
|
keyvalues={"ignored_user_id": user_id},
|
||||||
|
@ -384,6 +394,26 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached(max_entries=5000, iterable=True)
|
||||||
|
async def ignored_users(self, user_id: str) -> FrozenSet[str]:
|
||||||
|
"""
|
||||||
|
Get users which the given user ignores.
|
||||||
|
|
||||||
|
Params:
|
||||||
|
user_id: The user ID which is making the request.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
The user IDs which are ignored by the given user.
|
||||||
|
"""
|
||||||
|
return frozenset(
|
||||||
|
await self.db_pool.simple_select_onecol(
|
||||||
|
table="ignored_users",
|
||||||
|
keyvalues={"ignorer_user_id": user_id},
|
||||||
|
retcol="ignored_user_id",
|
||||||
|
desc="ignored_users",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def process_replication_rows(
|
def process_replication_rows(
|
||||||
self,
|
self,
|
||||||
stream_name: str,
|
stream_name: str,
|
||||||
|
@ -529,6 +559,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
||||||
else:
|
else:
|
||||||
currently_ignored_users = set()
|
currently_ignored_users = set()
|
||||||
|
|
||||||
|
# If the data has not changed, nothing to do.
|
||||||
|
if previously_ignored_users == currently_ignored_users:
|
||||||
|
return
|
||||||
|
|
||||||
# Delete entries which are no longer ignored.
|
# Delete entries which are no longer ignored.
|
||||||
self.db_pool.simple_delete_many_txn(
|
self.db_pool.simple_delete_many_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -551,6 +585,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
||||||
# Invalidate the cache for any ignored users which were added or removed.
|
# Invalidate the cache for any ignored users which were added or removed.
|
||||||
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
|
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
|
||||||
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
|
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
|
||||||
|
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
|
||||||
|
|
||||||
async def purge_account_data_for_user(self, user_id: str) -> None:
|
async def purge_account_data_for_user(self, user_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
|
||||||
EventsStream,
|
EventsStream,
|
||||||
EventsStreamCurrentStateRow,
|
EventsStreamCurrentStateRow,
|
||||||
EventsStreamEventRow,
|
EventsStreamEventRow,
|
||||||
|
EventsStreamRow,
|
||||||
)
|
)
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.database import (
|
from synapse.storage.database import (
|
||||||
|
@ -31,6 +32,7 @@ from synapse.storage.database import (
|
||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
)
|
)
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
|
from synapse.util.caches.descriptors import _CachedFunction
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
if last_id == current_id:
|
if last_id == current_id:
|
||||||
return [], current_id, False
|
return [], current_id, False
|
||||||
|
|
||||||
def get_all_updated_caches_txn(txn):
|
def get_all_updated_caches_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||||
# We purposefully don't bound by the current token, as we want to
|
# We purposefully don't bound by the current token, as we want to
|
||||||
# send across cache invalidations as quickly as possible. Cache
|
# send across cache invalidations as quickly as possible. Cache
|
||||||
# invalidations are idempotent, so duplicates are fine.
|
# invalidations are idempotent, so duplicates are fine.
|
||||||
|
@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
"get_all_updated_caches", get_all_updated_caches_txn
|
"get_all_updated_caches", get_all_updated_caches_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
def process_replication_rows(
|
||||||
|
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
|
||||||
|
) -> None:
|
||||||
if stream_name == EventsStream.NAME:
|
if stream_name == EventsStream.NAME:
|
||||||
for row in rows:
|
for row in rows:
|
||||||
self._process_event_stream_row(token, row)
|
self._process_event_stream_row(token, row)
|
||||||
|
@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
super().process_replication_rows(stream_name, instance_name, token, rows)
|
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
|
|
||||||
def _process_event_stream_row(self, token, row):
|
def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
|
||||||
data = row.data
|
data = row.data
|
||||||
|
|
||||||
if row.type == EventsStreamEventRow.TypeId:
|
if row.type == EventsStreamEventRow.TypeId:
|
||||||
|
assert isinstance(data, EventsStreamEventRow)
|
||||||
self._invalidate_caches_for_event(
|
self._invalidate_caches_for_event(
|
||||||
token,
|
token,
|
||||||
data.event_id,
|
data.event_id,
|
||||||
|
@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
backfilled=False,
|
backfilled=False,
|
||||||
)
|
)
|
||||||
elif row.type == EventsStreamCurrentStateRow.TypeId:
|
elif row.type == EventsStreamCurrentStateRow.TypeId:
|
||||||
self._curr_state_delta_stream_cache.entity_has_changed(
|
assert isinstance(data, EventsStreamCurrentStateRow)
|
||||||
row.data.room_id, token
|
self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
|
||||||
)
|
|
||||||
|
|
||||||
if data.type == EventTypes.Member:
|
if data.type == EventTypes.Member:
|
||||||
self.get_rooms_for_user_with_stream_ordering.invalidate(
|
self.get_rooms_for_user_with_stream_ordering.invalidate(
|
||||||
|
@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
def _invalidate_caches_for_event(
|
def _invalidate_caches_for_event(
|
||||||
self,
|
self,
|
||||||
stream_ordering,
|
stream_ordering: int,
|
||||||
event_id,
|
event_id: str,
|
||||||
room_id,
|
room_id: str,
|
||||||
etype,
|
etype: str,
|
||||||
state_key,
|
state_key: Optional[str],
|
||||||
redacts,
|
redacts: Optional[str],
|
||||||
relates_to,
|
relates_to: Optional[str],
|
||||||
backfilled,
|
backfilled: bool,
|
||||||
):
|
) -> None:
|
||||||
self._invalidate_get_event_cache(event_id)
|
self._invalidate_get_event_cache(event_id)
|
||||||
self.have_seen_event.invalidate((room_id, event_id))
|
self.have_seen_event.invalidate((room_id, event_id))
|
||||||
|
|
||||||
|
@ -186,6 +192,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
|
self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
|
||||||
|
|
||||||
|
# The `_get_membership_from_event_id` is immutable, except for the
|
||||||
|
# case where we look up an event *before* persisting it.
|
||||||
|
self._get_membership_from_event_id.invalidate((event_id,))
|
||||||
|
|
||||||
if not backfilled:
|
if not backfilled:
|
||||||
self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
|
self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
|
||||||
|
|
||||||
|
@ -207,7 +217,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
self.get_thread_summary.invalidate((relates_to,))
|
self.get_thread_summary.invalidate((relates_to,))
|
||||||
self.get_thread_participated.invalidate((relates_to,))
|
self.get_thread_participated.invalidate((relates_to,))
|
||||||
|
|
||||||
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
|
async def invalidate_cache_and_stream(
|
||||||
|
self, cache_name: str, keys: Tuple[Any, ...]
|
||||||
|
) -> None:
|
||||||
"""Invalidates the cache and adds it to the cache stream so slaves
|
"""Invalidates the cache and adds it to the cache stream so slaves
|
||||||
will know to invalidate their caches.
|
will know to invalidate their caches.
|
||||||
|
|
||||||
|
@ -227,7 +239,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
keys,
|
keys,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
|
def _invalidate_cache_and_stream(
|
||||||
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
cache_func: _CachedFunction,
|
||||||
|
keys: Tuple[Any, ...],
|
||||||
|
) -> None:
|
||||||
"""Invalidates the cache and adds it to the cache stream so slaves
|
"""Invalidates the cache and adds it to the cache stream so slaves
|
||||||
will know to invalidate their caches.
|
will know to invalidate their caches.
|
||||||
|
|
||||||
|
@ -238,7 +255,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
txn.call_after(cache_func.invalidate, keys)
|
txn.call_after(cache_func.invalidate, keys)
|
||||||
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
|
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
|
||||||
|
|
||||||
def _invalidate_all_cache_and_stream(self, txn, cache_func):
|
def _invalidate_all_cache_and_stream(
|
||||||
|
self, txn: LoggingTransaction, cache_func: _CachedFunction
|
||||||
|
) -> None:
|
||||||
"""Invalidates the entire cache and adds it to the cache stream so slaves
|
"""Invalidates the entire cache and adds it to the cache stream so slaves
|
||||||
will know to invalidate their caches.
|
will know to invalidate their caches.
|
||||||
"""
|
"""
|
||||||
|
@ -279,8 +298,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _send_invalidation_to_replication(
|
def _send_invalidation_to_replication(
|
||||||
self, txn, cache_name: str, keys: Optional[Iterable[Any]]
|
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
|
||||||
):
|
) -> None:
|
||||||
"""Notifies replication that given cache has been invalidated.
|
"""Notifies replication that given cache has been invalidated.
|
||||||
|
|
||||||
Note that this does *not* invalidate the cache locally.
|
Note that this does *not* invalidate the cache locally.
|
||||||
|
@ -315,7 +334,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
"instance_name": self._instance_name,
|
"instance_name": self._instance_name,
|
||||||
"cache_func": cache_name,
|
"cache_func": cache_name,
|
||||||
"keys": keys,
|
"keys": keys,
|
||||||
"invalidation_ts": self.clock.time_msec(),
|
"invalidation_ts": self._clock.time_msec(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1073,9 +1073,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
/* Get the depth and stream_ordering of the prev_event_id from the events table */
|
/* Get the depth and stream_ordering of the prev_event_id from the events table */
|
||||||
INNER JOIN events
|
INNER JOIN events
|
||||||
ON prev_event_id = events.event_id
|
ON prev_event_id = events.event_id
|
||||||
|
|
||||||
|
/* exclude outliers from the results (we don't have the state, so cannot
|
||||||
|
* verify if the requesting server can see them).
|
||||||
|
*/
|
||||||
|
WHERE NOT events.outlier
|
||||||
|
|
||||||
/* Look for an edge which matches the given event_id */
|
/* Look for an edge which matches the given event_id */
|
||||||
WHERE event_edges.event_id = ?
|
AND event_edges.event_id = ? AND NOT event_edges.is_state
|
||||||
AND event_edges.is_state = ?
|
|
||||||
/* Because we can have many events at the same depth,
|
/* Because we can have many events at the same depth,
|
||||||
* we want to also tie-break and sort on stream_ordering */
|
* we want to also tie-break and sort on stream_ordering */
|
||||||
ORDER BY depth DESC, stream_ordering DESC
|
ORDER BY depth DESC, stream_ordering DESC
|
||||||
|
@ -1084,7 +1090,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
txn.execute(
|
txn.execute(
|
||||||
connected_prev_event_query,
|
connected_prev_event_query,
|
||||||
(event_id, False, limit),
|
(event_id, limit),
|
||||||
)
|
)
|
||||||
return [
|
return [
|
||||||
BackfillQueueNavigationItem(
|
BackfillQueueNavigationItem(
|
||||||
|
|
|
@ -1745,6 +1745,13 @@ class PersistEventsStore:
|
||||||
(event.state_key,),
|
(event.state_key,),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# The `_get_membership_from_event_id` is immutable, except for the
|
||||||
|
# case where we look up an event *before* persisting it.
|
||||||
|
txn.call_after(
|
||||||
|
self.store._get_membership_from_event_id.invalidate,
|
||||||
|
(event.event_id,),
|
||||||
|
)
|
||||||
|
|
||||||
# We update the local_current_membership table only if the event is
|
# We update the local_current_membership table only if the event is
|
||||||
# "current", i.e., its something that has just happened.
|
# "current", i.e., its something that has just happened.
|
||||||
#
|
#
|
||||||
|
|
|
@ -13,13 +13,17 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
|
||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
from synapse.storage.database import (
|
||||||
|
DatabasePool,
|
||||||
|
LoggingDatabaseConnection,
|
||||||
|
LoggingTransaction,
|
||||||
|
)
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
|
|
||||||
|
@ -75,7 +79,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
# TODO: Pagination
|
# TODO: Pagination
|
||||||
|
|
||||||
keyvalues = {"group_id": group_id}
|
keyvalues: JsonDict = {"group_id": group_id}
|
||||||
if not include_private:
|
if not include_private:
|
||||||
keyvalues["is_public"] = True
|
keyvalues["is_public"] = True
|
||||||
|
|
||||||
|
@ -117,7 +121,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
# TODO: Pagination
|
# TODO: Pagination
|
||||||
|
|
||||||
def _get_rooms_in_group_txn(txn):
|
def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT room_id, is_public FROM group_rooms
|
SELECT room_id, is_public FROM group_rooms
|
||||||
WHERE group_id = ?
|
WHERE group_id = ?
|
||||||
|
@ -176,8 +180,10 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
* "order": int, the sort order of rooms in this category
|
* "order": int, the sort order of rooms in this category
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_rooms_for_summary_txn(txn):
|
def _get_rooms_for_summary_txn(
|
||||||
keyvalues = {"group_id": group_id}
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||||
|
keyvalues: JsonDict = {"group_id": group_id}
|
||||||
if not include_private:
|
if not include_private:
|
||||||
keyvalues["is_public"] = True
|
keyvalues["is_public"] = True
|
||||||
|
|
||||||
|
@ -241,7 +247,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
"get_rooms_for_summary", _get_rooms_for_summary_txn
|
"get_rooms_for_summary", _get_rooms_for_summary_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_group_categories(self, group_id):
|
async def get_group_categories(self, group_id: str) -> JsonDict:
|
||||||
rows = await self.db_pool.simple_select_list(
|
rows = await self.db_pool.simple_select_list(
|
||||||
table="group_room_categories",
|
table="group_room_categories",
|
||||||
keyvalues={"group_id": group_id},
|
keyvalues={"group_id": group_id},
|
||||||
|
@ -257,7 +263,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
for row in rows
|
for row in rows
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_group_category(self, group_id, category_id):
|
async def get_group_category(self, group_id: str, category_id: str) -> JsonDict:
|
||||||
category = await self.db_pool.simple_select_one(
|
category = await self.db_pool.simple_select_one(
|
||||||
table="group_room_categories",
|
table="group_room_categories",
|
||||||
keyvalues={"group_id": group_id, "category_id": category_id},
|
keyvalues={"group_id": group_id, "category_id": category_id},
|
||||||
|
@ -269,7 +275,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return category
|
return category
|
||||||
|
|
||||||
async def get_group_roles(self, group_id):
|
async def get_group_roles(self, group_id: str) -> JsonDict:
|
||||||
rows = await self.db_pool.simple_select_list(
|
rows = await self.db_pool.simple_select_list(
|
||||||
table="group_roles",
|
table="group_roles",
|
||||||
keyvalues={"group_id": group_id},
|
keyvalues={"group_id": group_id},
|
||||||
|
@ -285,7 +291,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
for row in rows
|
for row in rows
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_group_role(self, group_id, role_id):
|
async def get_group_role(self, group_id: str, role_id: str) -> JsonDict:
|
||||||
role = await self.db_pool.simple_select_one(
|
role = await self.db_pool.simple_select_one(
|
||||||
table="group_roles",
|
table="group_roles",
|
||||||
keyvalues={"group_id": group_id, "role_id": role_id},
|
keyvalues={"group_id": group_id, "role_id": role_id},
|
||||||
|
@ -311,15 +317,19 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
desc="get_local_groups_for_room",
|
desc="get_local_groups_for_room",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_users_for_summary_by_role(self, group_id, include_private=False):
|
async def get_users_for_summary_by_role(
|
||||||
|
self, group_id: str, include_private: bool = False
|
||||||
|
) -> Tuple[List[JsonDict], JsonDict]:
|
||||||
"""Get the users and roles that should be included in a summary request
|
"""Get the users and roles that should be included in a summary request
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
([users], [roles])
|
([users], [roles])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_users_for_summary_txn(txn):
|
def _get_users_for_summary_txn(
|
||||||
keyvalues = {"group_id": group_id}
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[List[JsonDict], JsonDict]:
|
||||||
|
keyvalues: JsonDict = {"group_id": group_id}
|
||||||
if not include_private:
|
if not include_private:
|
||||||
keyvalues["is_public"] = True
|
keyvalues["is_public"] = True
|
||||||
|
|
||||||
|
@ -406,7 +416,9 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_users_membership_info_in_group(self, group_id, user_id):
|
async def get_users_membership_info_in_group(
|
||||||
|
self, group_id: str, user_id: str
|
||||||
|
) -> JsonDict:
|
||||||
"""Get a dict describing the membership of a user in a group.
|
"""Get a dict describing the membership of a user in a group.
|
||||||
|
|
||||||
Example if joined:
|
Example if joined:
|
||||||
|
@ -421,7 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
An empty dict if the user is not join/invite/etc
|
An empty dict if the user is not join/invite/etc
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_users_membership_in_group_txn(txn):
|
def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict:
|
||||||
row = self.db_pool.simple_select_one_txn(
|
row = self.db_pool.simple_select_one_txn(
|
||||||
txn,
|
txn,
|
||||||
table="group_users",
|
table="group_users",
|
||||||
|
@ -463,10 +475,14 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
desc="get_publicised_groups_for_user",
|
desc="get_publicised_groups_for_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_attestations_need_renewals(self, valid_until_ms):
|
async def get_attestations_need_renewals(
|
||||||
|
self, valid_until_ms: int
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""Get all attestations that need to be renewed until givent time"""
|
"""Get all attestations that need to be renewed until givent time"""
|
||||||
|
|
||||||
def _get_attestations_need_renewals_txn(txn):
|
def _get_attestations_need_renewals_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT group_id, user_id FROM group_attestations_renewals
|
SELECT group_id, user_id FROM group_attestations_renewals
|
||||||
WHERE valid_until_ms <= ?
|
WHERE valid_until_ms <= ?
|
||||||
|
@ -478,7 +494,9 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
|
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_remote_attestation(self, group_id, user_id):
|
async def get_remote_attestation(
|
||||||
|
self, group_id: str, user_id: str
|
||||||
|
) -> Optional[JsonDict]:
|
||||||
"""Get the attestation that proves the remote agrees that the user is
|
"""Get the attestation that proves the remote agrees that the user is
|
||||||
in the group.
|
in the group.
|
||||||
"""
|
"""
|
||||||
|
@ -504,8 +522,8 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
desc="get_joined_groups",
|
desc="get_joined_groups",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_all_groups_for_user(self, user_id, now_token):
|
async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]:
|
||||||
def _get_all_groups_for_user_txn(txn):
|
def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT group_id, type, membership, u.content
|
SELECT group_id, type, membership, u.content
|
||||||
FROM local_group_updates AS u
|
FROM local_group_updates AS u
|
||||||
|
@ -528,15 +546,16 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
"get_all_groups_for_user", _get_all_groups_for_user_txn
|
"get_all_groups_for_user", _get_all_groups_for_user_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_groups_changes_for_user(self, user_id, from_token, to_token):
|
async def get_groups_changes_for_user(
|
||||||
from_token = int(from_token)
|
self, user_id: str, from_token: int, to_token: int
|
||||||
has_changed = self._group_updates_stream_cache.has_entity_changed(
|
) -> List[JsonDict]:
|
||||||
|
has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined]
|
||||||
user_id, from_token
|
user_id, from_token
|
||||||
)
|
)
|
||||||
if not has_changed:
|
if not has_changed:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _get_groups_changes_for_user_txn(txn):
|
def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT group_id, membership, type, u.content
|
SELECT group_id, membership, type, u.content
|
||||||
FROM local_group_updates AS u
|
FROM local_group_updates AS u
|
||||||
|
@ -583,12 +602,14 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
last_id = int(last_id)
|
last_id = int(last_id)
|
||||||
has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
|
has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined]
|
||||||
|
|
||||||
if not has_changed:
|
if not has_changed:
|
||||||
return [], current_id, False
|
return [], current_id, False
|
||||||
|
|
||||||
def _get_all_groups_changes_txn(txn):
|
def _get_all_groups_changes_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT stream_id, group_id, user_id, type, content
|
SELECT stream_id, group_id, user_id, type, content
|
||||||
FROM local_group_updates
|
FROM local_group_updates
|
||||||
|
@ -596,10 +617,13 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (last_id, current_id, limit))
|
txn.execute(sql, (last_id, current_id, limit))
|
||||||
updates = [
|
updates = cast(
|
||||||
|
List[Tuple[int, tuple]],
|
||||||
|
[
|
||||||
(stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
|
(stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
|
||||||
for stream_id, group_id, user_id, gtype, content_json in txn
|
for stream_id, group_id, user_id, gtype, content_json in txn
|
||||||
]
|
],
|
||||||
|
)
|
||||||
|
|
||||||
limited = False
|
limited = False
|
||||||
upto_token = current_id
|
upto_token = current_id
|
||||||
|
@ -633,8 +657,8 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
self,
|
self,
|
||||||
group_id: str,
|
group_id: str,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
category_id: str,
|
category_id: Optional[str],
|
||||||
order: int,
|
order: Optional[int],
|
||||||
is_public: Optional[bool],
|
is_public: Optional[bool],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add (or update) room's entry in summary.
|
"""Add (or update) room's entry in summary.
|
||||||
|
@ -661,11 +685,11 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
|
|
||||||
def _add_room_to_summary_txn(
|
def _add_room_to_summary_txn(
|
||||||
self,
|
self,
|
||||||
txn,
|
txn: LoggingTransaction,
|
||||||
group_id: str,
|
group_id: str,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
category_id: str,
|
category_id: Optional[str],
|
||||||
order: int,
|
order: Optional[int],
|
||||||
is_public: Optional[bool],
|
is_public: Optional[bool],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add (or update) room's entry in summary.
|
"""Add (or update) room's entry in summary.
|
||||||
|
@ -750,7 +774,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
WHERE group_id = ? AND category_id = ?
|
WHERE group_id = ? AND category_id = ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (group_id, category_id))
|
txn.execute(sql, (group_id, category_id))
|
||||||
(order,) = txn.fetchone()
|
(order,) = cast(Tuple[int], txn.fetchone())
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
to_update = {}
|
to_update = {}
|
||||||
|
@ -766,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
"category_id": category_id,
|
"category_id": category_id,
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
},
|
},
|
||||||
values=to_update,
|
updatevalues=to_update,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if is_public is None:
|
if is_public is None:
|
||||||
|
@ -785,7 +809,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def remove_room_from_summary(
|
async def remove_room_from_summary(
|
||||||
self, group_id: str, room_id: str, category_id: str
|
self, group_id: str, room_id: str, category_id: Optional[str]
|
||||||
) -> int:
|
) -> int:
|
||||||
if category_id is None:
|
if category_id is None:
|
||||||
category_id = _DEFAULT_CATEGORY_ID
|
category_id = _DEFAULT_CATEGORY_ID
|
||||||
|
@ -808,8 +832,8 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
is_public: Optional[bool],
|
is_public: Optional[bool],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add/update room category for group"""
|
"""Add/update room category for group"""
|
||||||
insertion_values = {}
|
insertion_values: JsonDict = {}
|
||||||
update_values = {"category_id": category_id} # This cannot be empty
|
update_values: JsonDict = {"category_id": category_id} # This cannot be empty
|
||||||
|
|
||||||
if profile is None:
|
if profile is None:
|
||||||
insertion_values["profile"] = "{}"
|
insertion_values["profile"] = "{}"
|
||||||
|
@ -844,8 +868,8 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
is_public: Optional[bool],
|
is_public: Optional[bool],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add/remove user role"""
|
"""Add/remove user role"""
|
||||||
insertion_values = {}
|
insertion_values: JsonDict = {}
|
||||||
update_values = {"role_id": role_id} # This cannot be empty
|
update_values: JsonDict = {"role_id": role_id} # This cannot be empty
|
||||||
|
|
||||||
if profile is None:
|
if profile is None:
|
||||||
insertion_values["profile"] = "{}"
|
insertion_values["profile"] = "{}"
|
||||||
|
@ -876,8 +900,8 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
self,
|
self,
|
||||||
group_id: str,
|
group_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
role_id: str,
|
role_id: Optional[str],
|
||||||
order: int,
|
order: Optional[int],
|
||||||
is_public: Optional[bool],
|
is_public: Optional[bool],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add (or update) user's entry in summary.
|
"""Add (or update) user's entry in summary.
|
||||||
|
@ -904,13 +928,13 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
|
|
||||||
def _add_user_to_summary_txn(
|
def _add_user_to_summary_txn(
|
||||||
self,
|
self,
|
||||||
txn,
|
txn: LoggingTransaction,
|
||||||
group_id: str,
|
group_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
role_id: str,
|
role_id: Optional[str],
|
||||||
order: int,
|
order: Optional[int],
|
||||||
is_public: Optional[bool],
|
is_public: Optional[bool],
|
||||||
):
|
) -> None:
|
||||||
"""Add (or update) user's entry in summary.
|
"""Add (or update) user's entry in summary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -989,7 +1013,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
WHERE group_id = ? AND role_id = ?
|
WHERE group_id = ? AND role_id = ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (group_id, role_id))
|
txn.execute(sql, (group_id, role_id))
|
||||||
(order,) = txn.fetchone()
|
(order,) = cast(Tuple[int], txn.fetchone())
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
to_update = {}
|
to_update = {}
|
||||||
|
@ -1005,7 +1029,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
"role_id": role_id,
|
"role_id": role_id,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
},
|
},
|
||||||
values=to_update,
|
updatevalues=to_update,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if is_public is None:
|
if is_public is None:
|
||||||
|
@ -1024,7 +1048,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def remove_user_from_summary(
|
async def remove_user_from_summary(
|
||||||
self, group_id: str, user_id: str, role_id: str
|
self, group_id: str, user_id: str, role_id: Optional[str]
|
||||||
) -> int:
|
) -> int:
|
||||||
if role_id is None:
|
if role_id is None:
|
||||||
role_id = _DEFAULT_ROLE_ID
|
role_id = _DEFAULT_ROLE_ID
|
||||||
|
@ -1065,7 +1089,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
Optional if the user and group are on the same server
|
Optional if the user and group are on the same server
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _add_user_to_group_txn(txn):
|
def _add_user_to_group_txn(txn: LoggingTransaction) -> None:
|
||||||
self.db_pool.simple_insert_txn(
|
self.db_pool.simple_insert_txn(
|
||||||
txn,
|
txn,
|
||||||
table="group_users",
|
table="group_users",
|
||||||
|
@ -1108,7 +1132,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
|
await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
|
||||||
|
|
||||||
async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
|
async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
|
||||||
def _remove_user_from_group_txn(txn):
|
def _remove_user_from_group_txn(txn: LoggingTransaction) -> None:
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
txn,
|
txn,
|
||||||
table="group_users",
|
table="group_users",
|
||||||
|
@ -1159,7 +1183,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
|
async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
|
||||||
def _remove_room_from_group_txn(txn):
|
def _remove_room_from_group_txn(txn: LoggingTransaction) -> None:
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
txn,
|
txn,
|
||||||
table="group_rooms",
|
table="group_rooms",
|
||||||
|
@ -1216,7 +1240,9 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
|
|
||||||
content = content or {}
|
content = content or {}
|
||||||
|
|
||||||
def _register_user_group_membership_txn(txn, next_id):
|
def _register_user_group_membership_txn(
|
||||||
|
txn: LoggingTransaction, next_id: int
|
||||||
|
) -> int:
|
||||||
# TODO: Upsert?
|
# TODO: Upsert?
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -1249,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
|
self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined]
|
||||||
|
|
||||||
# TODO: Insert profile to ensure it comes down stream if its a join.
|
# TODO: Insert profile to ensure it comes down stream if its a join.
|
||||||
|
|
||||||
|
@ -1289,7 +1315,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
|
|
||||||
return next_id
|
return next_id
|
||||||
|
|
||||||
async with self._group_updates_id_gen.get_next() as next_id:
|
async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined]
|
||||||
res = await self.db_pool.runInteraction(
|
res = await self.db_pool.runInteraction(
|
||||||
"register_user_group_membership",
|
"register_user_group_membership",
|
||||||
_register_user_group_membership_txn,
|
_register_user_group_membership_txn,
|
||||||
|
@ -1298,7 +1324,13 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def create_group(
|
async def create_group(
|
||||||
self, group_id, user_id, name, avatar_url, short_description, long_description
|
self,
|
||||||
|
group_id: str,
|
||||||
|
user_id: str,
|
||||||
|
name: str,
|
||||||
|
avatar_url: str,
|
||||||
|
short_description: str,
|
||||||
|
long_description: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
await self.db_pool.simple_insert(
|
await self.db_pool.simple_insert(
|
||||||
table="groups",
|
table="groups",
|
||||||
|
@ -1313,7 +1345,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
desc="create_group",
|
desc="create_group",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def update_group_profile(self, group_id, profile):
|
async def update_group_profile(self, group_id: str, profile: JsonDict) -> None:
|
||||||
await self.db_pool.simple_update_one(
|
await self.db_pool.simple_update_one(
|
||||||
table="groups",
|
table="groups",
|
||||||
keyvalues={"group_id": group_id},
|
keyvalues={"group_id": group_id},
|
||||||
|
@ -1361,8 +1393,8 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
desc="remove_attestation_renewal",
|
desc="remove_attestation_renewal",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_group_stream_token(self):
|
def get_group_stream_token(self) -> int:
|
||||||
return self._group_updates_id_gen.get_current_token()
|
return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined]
|
||||||
|
|
||||||
async def delete_group(self, group_id: str) -> None:
|
async def delete_group(self, group_id: str) -> None:
|
||||||
"""Deletes a group fully from the database.
|
"""Deletes a group fully from the database.
|
||||||
|
@ -1371,7 +1403,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||||
group_id: The group ID to delete.
|
group_id: The group ID to delete.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _delete_group_txn(txn):
|
def _delete_group_txn(txn: LoggingTransaction) -> None:
|
||||||
tables = [
|
tables = [
|
||||||
"groups",
|
"groups",
|
||||||
"group_users",
|
"group_users",
|
||||||
|
|
|
@ -156,7 +156,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
hs: "HomeServer",
|
hs: "HomeServer",
|
||||||
):
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
self.server_name = hs.hostname
|
self.server_name: str = hs.hostname
|
||||||
|
|
||||||
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
|
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""Get the metadata for a local piece of media
|
"""Get the metadata for a local piece of media
|
||||||
|
|
|
@ -12,15 +12,17 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.database import (
|
from synapse.storage.database import (
|
||||||
DatabasePool,
|
DatabasePool,
|
||||||
LoggingDatabaseConnection,
|
LoggingDatabaseConnection,
|
||||||
|
LoggingTransaction,
|
||||||
make_in_list_sql_clause,
|
make_in_list_sql_clause,
|
||||||
)
|
)
|
||||||
|
from synapse.storage.databases.main.registration import RegistrationWorkerStore
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
from synapse.util.threepids import canonicalise_email
|
from synapse.util.threepids import canonicalise_email
|
||||||
|
|
||||||
|
@ -56,7 +58,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
|
||||||
Number of current monthly active users
|
Number of current monthly active users
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _count_users(txn):
|
def _count_users(txn: LoggingTransaction) -> int:
|
||||||
# Exclude app service users
|
# Exclude app service users
|
||||||
sql = """
|
sql = """
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
|
@ -66,7 +68,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
|
||||||
WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
|
WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
|
||||||
"""
|
"""
|
||||||
txn.execute(sql)
|
txn.execute(sql)
|
||||||
(count,) = txn.fetchone()
|
(count,) = cast(Tuple[int], txn.fetchone())
|
||||||
return count
|
return count
|
||||||
|
|
||||||
return await self.db_pool.runInteraction("count_users", _count_users)
|
return await self.db_pool.runInteraction("count_users", _count_users)
|
||||||
|
@ -84,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _count_users_by_service(txn):
|
def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT COALESCE(appservice_id, 'native'), COUNT(*)
|
SELECT COALESCE(appservice_id, 'native'), COUNT(*)
|
||||||
FROM monthly_active_users
|
FROM monthly_active_users
|
||||||
|
@ -93,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
txn.execute(sql)
|
txn.execute(sql)
|
||||||
result = txn.fetchall()
|
result = cast(List[Tuple[str, int]], txn.fetchall())
|
||||||
return dict(result)
|
return dict(result)
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
|
@ -141,12 +143,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
@wrap_as_background_process("reap_monthly_active_users")
|
@wrap_as_background_process("reap_monthly_active_users")
|
||||||
async def reap_monthly_active_users(self):
|
async def reap_monthly_active_users(self) -> None:
|
||||||
"""Cleans out monthly active user table to ensure that no stale
|
"""Cleans out monthly active user table to ensure that no stale
|
||||||
entries exist.
|
entries exist.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _reap_users(txn, reserved_users):
|
def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
reserved_users (tuple): reserved users to preserve
|
reserved_users (tuple): reserved users to preserve
|
||||||
|
@ -210,10 +212,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
|
||||||
# is racy.
|
# is racy.
|
||||||
# Have resolved to invalidate the whole cache for now and do
|
# Have resolved to invalidate the whole cache for now and do
|
||||||
# something about it if and when the perf becomes significant
|
# something about it if and when the perf becomes significant
|
||||||
self._invalidate_all_cache_and_stream(
|
self._invalidate_all_cache_and_stream( # type: ignore[attr-defined]
|
||||||
txn, self.user_last_seen_monthly_active
|
txn, self.user_last_seen_monthly_active
|
||||||
)
|
)
|
||||||
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
|
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined]
|
||||||
|
|
||||||
reserved_users = await self.get_registered_reserved_users()
|
reserved_users = await self.get_registered_reserved_users()
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
|
@ -221,7 +223,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
|
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
database: DatabasePool,
|
database: DatabasePool,
|
||||||
|
@ -242,13 +244,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
|
||||||
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
|
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _initialise_reserved_users(self, txn, threepids):
|
def _initialise_reserved_users(
|
||||||
|
self, txn: LoggingTransaction, threepids: List[dict]
|
||||||
|
) -> None:
|
||||||
"""Ensures that reserved threepids are accounted for in the MAU table, should
|
"""Ensures that reserved threepids are accounted for in the MAU table, should
|
||||||
be called on start up.
|
be called on start up.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn (cursor):
|
txn:
|
||||||
threepids (list[dict]): List of threepid dicts to reserve
|
threepids: List of threepid dicts to reserve
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# XXX what is this function trying to achieve? It upserts into
|
# XXX what is this function trying to achieve? It upserts into
|
||||||
|
@ -299,7 +303,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
|
||||||
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
|
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def upsert_monthly_active_user_txn(self, txn, user_id):
|
def upsert_monthly_active_user_txn(
|
||||||
|
self, txn: LoggingTransaction, user_id: str
|
||||||
|
) -> None:
|
||||||
"""Updates or inserts monthly active user member
|
"""Updates or inserts monthly active user member
|
||||||
|
|
||||||
We consciously do not call is_support_txn from this method because it
|
We consciously do not call is_support_txn from this method because it
|
||||||
|
@ -336,7 +342,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
|
||||||
txn, self.user_last_seen_monthly_active, (user_id,)
|
txn, self.user_last_seen_monthly_active, (user_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def populate_monthly_active_users(self, user_id):
|
async def populate_monthly_active_users(self, user_id: str) -> None:
|
||||||
"""Checks on the state of monthly active user limits and optionally
|
"""Checks on the state of monthly active user limits and optionally
|
||||||
add the user to the monthly active tables
|
add the user to the monthly active tables
|
||||||
|
|
||||||
|
@ -345,7 +351,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
|
||||||
"""
|
"""
|
||||||
if self._limit_usage_by_mau or self._mau_stats_only:
|
if self._limit_usage_by_mau or self._mau_stats_only:
|
||||||
# Trial users and guests should not be included as part of MAU group
|
# Trial users and guests should not be included as part of MAU group
|
||||||
is_guest = await self.is_guest(user_id)
|
is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]
|
||||||
if is_guest:
|
if is_guest:
|
||||||
return
|
return
|
||||||
is_trial = await self.is_trial_user(user_id)
|
is_trial = await self.is_trial_user(user_id)
|
||||||
|
|
|
@ -24,10 +24,9 @@ from typing import (
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.constants import ReceiptTypes
|
from synapse.api.constants import ReceiptTypes
|
||||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||||
from synapse.replication.tcp.streams import ReceiptsStream
|
from synapse.replication.tcp.streams import ReceiptsStream
|
||||||
|
@ -38,7 +37,11 @@ from synapse.storage.database import (
|
||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
)
|
)
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
from synapse.storage.util.id_generators import (
|
||||||
|
AbstractStreamIdTracker,
|
||||||
|
MultiWriterIdGenerator,
|
||||||
|
StreamIdGenerator,
|
||||||
|
)
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
@ -58,6 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
hs: "HomeServer",
|
hs: "HomeServer",
|
||||||
):
|
):
|
||||||
self._instance_name = hs.get_instance_name()
|
self._instance_name = hs.get_instance_name()
|
||||||
|
self._receipts_id_gen: AbstractStreamIdTracker
|
||||||
|
|
||||||
if isinstance(database.engine, PostgresEngine):
|
if isinstance(database.engine, PostgresEngine):
|
||||||
self._can_write_to_receipts = (
|
self._can_write_to_receipts = (
|
||||||
|
@ -161,7 +165,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
" AND user_id = ?"
|
" AND user_id = ?"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (user_id,))
|
txn.execute(sql, (user_id,))
|
||||||
return txn.fetchall()
|
return cast(List[Tuple[str, str, int, int]], txn.fetchall())
|
||||||
|
|
||||||
rows = await self.db_pool.runInteraction(
|
rows = await self.db_pool.runInteraction(
|
||||||
"get_receipts_for_user_with_orderings", f
|
"get_receipts_for_user_with_orderings", f
|
||||||
|
@ -257,7 +261,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
if not rows:
|
if not rows:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
content = {}
|
content: JsonDict = {}
|
||||||
for row in rows:
|
for row in rows:
|
||||||
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
|
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
|
||||||
row["user_id"]
|
row["user_id"]
|
||||||
|
@ -305,7 +309,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
"_get_linearized_receipts_for_rooms", f
|
"_get_linearized_receipts_for_rooms", f
|
||||||
)
|
)
|
||||||
|
|
||||||
results = {}
|
results: JsonDict = {}
|
||||||
for row in txn_results:
|
for row in txn_results:
|
||||||
# We want a single event per room, since we want to batch the
|
# We want a single event per room, since we want to batch the
|
||||||
# receipts by room, event and type.
|
# receipts by room, event and type.
|
||||||
|
@ -370,7 +374,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
"get_linearized_receipts_for_all_rooms", f
|
"get_linearized_receipts_for_all_rooms", f
|
||||||
)
|
)
|
||||||
|
|
||||||
results = {}
|
results: JsonDict = {}
|
||||||
for row in txn_results:
|
for row in txn_results:
|
||||||
# We want a single event per room, since we want to batch the
|
# We want a single event per room, since we want to batch the
|
||||||
# receipts by room, event and type.
|
# receipts by room, event and type.
|
||||||
|
@ -399,7 +403,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if last_id == current_id:
|
if last_id == current_id:
|
||||||
return defer.succeed([])
|
return []
|
||||||
|
|
||||||
def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
|
def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
|
||||||
sql = """
|
sql = """
|
||||||
|
@ -453,7 +457,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (last_id, current_id, limit))
|
txn.execute(sql, (last_id, current_id, limit))
|
||||||
|
|
||||||
updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]
|
updates = cast(
|
||||||
|
List[Tuple[int, list]],
|
||||||
|
[(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
|
||||||
|
)
|
||||||
|
|
||||||
limited = False
|
limited = False
|
||||||
upper_bound = current_id
|
upper_bound = current_id
|
||||||
|
@ -496,7 +503,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
|
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
|
||||||
self.get_receipts_for_room.invalidate((room_id, receipt_type))
|
self.get_receipts_for_room.invalidate((room_id, receipt_type))
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
def process_replication_rows(
|
||||||
|
self,
|
||||||
|
stream_name: str,
|
||||||
|
instance_name: str,
|
||||||
|
token: int,
|
||||||
|
rows: Iterable[Any],
|
||||||
|
) -> None:
|
||||||
if stream_name == ReceiptsStream.NAME:
|
if stream_name == ReceiptsStream.NAME:
|
||||||
self._receipts_id_gen.advance(instance_name, token)
|
self._receipts_id_gen.advance(instance_name, token)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
@ -584,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
|
if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
|
||||||
self._remove_old_push_actions_before_txn(
|
self._remove_old_push_actions_before_txn( # type: ignore[attr-defined]
|
||||||
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
|
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -637,7 +650,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
"insert_receipt_conv", graph_to_linear
|
"insert_receipt_conv", graph_to_linear
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self._receipts_id_gen.get_next() as stream_id:
|
async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
|
||||||
event_ts = await self.db_pool.runInteraction(
|
event_ts = await self.db_pool.runInteraction(
|
||||||
"insert_linearized_receipt",
|
"insert_linearized_receipt",
|
||||||
self.insert_linearized_receipt_txn,
|
self.insert_linearized_receipt_txn,
|
||||||
|
|
|
@ -22,6 +22,7 @@ import attr
|
||||||
|
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
||||||
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.storage.database import (
|
from synapse.storage.database import (
|
||||||
DatabasePool,
|
DatabasePool,
|
||||||
|
@ -123,7 +124,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
):
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self.config = hs.config
|
self.config: HomeServerConfig = hs.config
|
||||||
|
|
||||||
# Note: we don't check this sequence for consistency as we'd have to
|
# Note: we don't check this sequence for consistency as we'd have to
|
||||||
# call `find_max_generated_user_id_localpart` each time, which is
|
# call `find_max_generated_user_id_localpart` each time, which is
|
||||||
|
|
|
@ -27,7 +27,6 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from frozendict import frozendict
|
|
||||||
|
|
||||||
from synapse.api.constants import RelationTypes
|
from synapse.api.constants import RelationTypes
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
|
@ -41,45 +40,15 @@ from synapse.storage.database import (
|
||||||
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
|
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
|
||||||
from synapse.types import JsonDict, RoomStreamToken, StreamToken
|
from synapse.types import RoomStreamToken, StreamToken
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.databases.main import DataStore
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
|
||||||
class _ThreadAggregation:
|
|
||||||
# The latest event in the thread.
|
|
||||||
latest_event: EventBase
|
|
||||||
# The latest edit to the latest event in the thread.
|
|
||||||
latest_edit: Optional[EventBase]
|
|
||||||
# The total number of events in the thread.
|
|
||||||
count: int
|
|
||||||
# True if the current user has sent an event to the thread.
|
|
||||||
current_user_participated: bool
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, auto_attribs=True)
|
|
||||||
class BundledAggregations:
|
|
||||||
"""
|
|
||||||
The bundled aggregations for an event.
|
|
||||||
|
|
||||||
Some values require additional processing during serialization.
|
|
||||||
"""
|
|
||||||
|
|
||||||
annotations: Optional[JsonDict] = None
|
|
||||||
references: Optional[JsonDict] = None
|
|
||||||
replace: Optional[EventBase] = None
|
|
||||||
thread: Optional[_ThreadAggregation] = None
|
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
|
||||||
return bool(self.annotations or self.references or self.replace or self.thread)
|
|
||||||
|
|
||||||
|
|
||||||
class RelationsWorkerStore(SQLBaseStore):
|
class RelationsWorkerStore(SQLBaseStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
|
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
|
||||||
async def _get_applicable_edits(
|
async def get_applicable_edits(
|
||||||
self, event_ids: Collection[str]
|
self, event_ids: Collection[str]
|
||||||
) -> Dict[str, Optional[EventBase]]:
|
) -> Dict[str, Optional[EventBase]]:
|
||||||
"""Get the most recent edit (if any) that has happened for the given
|
"""Get the most recent edit (if any) that has happened for the given
|
||||||
|
@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
|
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
|
||||||
async def _get_thread_summaries(
|
async def get_thread_summaries(
|
||||||
self, event_ids: Collection[str]
|
self, event_ids: Collection[str]
|
||||||
) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
|
) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
|
||||||
"""Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
|
"""Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
|
||||||
|
@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
|
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
|
||||||
|
|
||||||
# Check to see if any of those events are edited.
|
# Check to see if any of those events are edited.
|
||||||
latest_edits = await self._get_applicable_edits(latest_event_ids.values())
|
latest_edits = await self.get_applicable_edits(latest_event_ids.values())
|
||||||
|
|
||||||
# Map to the event IDs to the thread summary.
|
# Map to the event IDs to the thread summary.
|
||||||
#
|
#
|
||||||
|
@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
|
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
|
||||||
async def _get_threads_participated(
|
async def get_threads_participated(
|
||||||
self, event_ids: Collection[str], user_id: str
|
self, event_ids: Collection[str], user_id: str
|
||||||
) -> Dict[str, bool]:
|
) -> Dict[str, bool]:
|
||||||
"""Get whether the requesting user participated in the given threads.
|
"""Get whether the requesting user participated in the given threads.
|
||||||
|
@ -766,114 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
|
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _get_bundled_aggregation_for_event(
|
|
||||||
self, event: EventBase, user_id: str
|
|
||||||
) -> Optional[BundledAggregations]:
|
|
||||||
"""Generate bundled aggregations for an event.
|
|
||||||
|
|
||||||
Note that this does not use a cache, but depends on cached methods.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event: The event to calculate bundled aggregations for.
|
|
||||||
user_id: The user requesting the bundled aggregations.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The bundled aggregations for an event, if bundled aggregations are
|
|
||||||
enabled and the event can have bundled aggregations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Do not bundle aggregations for an event which represents an edit or an
|
|
||||||
# annotation. It does not make sense for them to have related events.
|
|
||||||
relates_to = event.content.get("m.relates_to")
|
|
||||||
if isinstance(relates_to, (dict, frozendict)):
|
|
||||||
relation_type = relates_to.get("rel_type")
|
|
||||||
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
|
|
||||||
return None
|
|
||||||
|
|
||||||
event_id = event.event_id
|
|
||||||
room_id = event.room_id
|
|
||||||
|
|
||||||
# The bundled aggregations to include, a mapping of relation type to a
|
|
||||||
# type-specific value. Some types include the direct return type here
|
|
||||||
# while others need more processing during serialization.
|
|
||||||
aggregations = BundledAggregations()
|
|
||||||
|
|
||||||
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
|
|
||||||
if annotations.chunk:
|
|
||||||
aggregations.annotations = await annotations.to_dict(
|
|
||||||
cast("DataStore", self)
|
|
||||||
)
|
|
||||||
|
|
||||||
references = await self.get_relations_for_event(
|
|
||||||
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
|
|
||||||
)
|
|
||||||
if references.chunk:
|
|
||||||
aggregations.references = await references.to_dict(cast("DataStore", self))
|
|
||||||
|
|
||||||
# Store the bundled aggregations in the event metadata for later use.
|
|
||||||
return aggregations
|
|
||||||
|
|
||||||
async def get_bundled_aggregations(
|
|
||||||
self, events: Iterable[EventBase], user_id: str
|
|
||||||
) -> Dict[str, BundledAggregations]:
|
|
||||||
"""Generate bundled aggregations for events.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
events: The iterable of events to calculate bundled aggregations for.
|
|
||||||
user_id: The user requesting the bundled aggregations.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A map of event ID to the bundled aggregation for the event. Not all
|
|
||||||
events may have bundled aggregations in the results.
|
|
||||||
"""
|
|
||||||
# De-duplicate events by ID to handle the same event requested multiple times.
|
|
||||||
#
|
|
||||||
# State events do not get bundled aggregations.
|
|
||||||
events_by_id = {
|
|
||||||
event.event_id: event for event in events if not event.is_state()
|
|
||||||
}
|
|
||||||
|
|
||||||
# event ID -> bundled aggregation in non-serialized form.
|
|
||||||
results: Dict[str, BundledAggregations] = {}
|
|
||||||
|
|
||||||
# Fetch other relations per event.
|
|
||||||
for event in events_by_id.values():
|
|
||||||
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
|
|
||||||
if event_result:
|
|
||||||
results[event.event_id] = event_result
|
|
||||||
|
|
||||||
# Fetch any edits (but not for redacted events).
|
|
||||||
edits = await self._get_applicable_edits(
|
|
||||||
[
|
|
||||||
event_id
|
|
||||||
for event_id, event in events_by_id.items()
|
|
||||||
if not event.internal_metadata.is_redacted()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
for event_id, edit in edits.items():
|
|
||||||
results.setdefault(event_id, BundledAggregations()).replace = edit
|
|
||||||
|
|
||||||
# Fetch thread summaries.
|
|
||||||
summaries = await self._get_thread_summaries(events_by_id.keys())
|
|
||||||
# Only fetch participated for a limited selection based on what had
|
|
||||||
# summaries.
|
|
||||||
participated = await self._get_threads_participated(summaries.keys(), user_id)
|
|
||||||
for event_id, summary in summaries.items():
|
|
||||||
if summary:
|
|
||||||
thread_count, latest_thread_event, edit = summary
|
|
||||||
results.setdefault(
|
|
||||||
event_id, BundledAggregations()
|
|
||||||
).thread = _ThreadAggregation(
|
|
||||||
latest_event=latest_thread_event,
|
|
||||||
latest_edit=edit,
|
|
||||||
count=thread_count,
|
|
||||||
# If there's a thread summary it must also exist in the
|
|
||||||
# participated dictionary.
|
|
||||||
current_user_participated=participated[event_id],
|
|
||||||
)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
class RelationsStore(RelationsWorkerStore):
|
class RelationsStore(RelationsWorkerStore):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -34,6 +34,7 @@ import attr
|
||||||
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
|
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.api.room_versions import RoomVersion, RoomVersions
|
from synapse.api.room_versions import RoomVersion, RoomVersions
|
||||||
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import (
|
from synapse.storage.database import (
|
||||||
|
@ -98,7 +99,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||||
):
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self.config = hs.config
|
self.config: HomeServerConfig = hs.config
|
||||||
|
|
||||||
async def store_room(
|
async def store_room(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
|
||||||
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
|
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||||
|
class EventIdMembership:
|
||||||
|
"""Returned by `get_membership_from_event_ids`"""
|
||||||
|
|
||||||
|
user_id: str
|
||||||
|
membership: str
|
||||||
|
|
||||||
|
|
||||||
class RoomMemberWorkerStore(EventsWorkerStore):
|
class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
retcols=("user_id", "display_name", "avatar_url", "event_id"),
|
retcols=("user_id", "display_name", "avatar_url", "event_id"),
|
||||||
keyvalues={"membership": Membership.JOIN},
|
keyvalues={"membership": Membership.JOIN},
|
||||||
batch_size=500,
|
batch_size=500,
|
||||||
desc="_get_membership_from_event_ids",
|
desc="_get_joined_profiles_from_event_ids",
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
|
|
||||||
return set(room_ids)
|
return set(room_ids)
|
||||||
|
|
||||||
|
@cached(max_entries=5000)
|
||||||
|
async def _get_membership_from_event_id(
|
||||||
|
self, member_event_id: str
|
||||||
|
) -> Optional[EventIdMembership]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@cachedList(
|
||||||
|
cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
|
||||||
|
)
|
||||||
async def get_membership_from_event_ids(
|
async def get_membership_from_event_ids(
|
||||||
self, member_event_ids: Iterable[str]
|
self, member_event_ids: Iterable[str]
|
||||||
) -> List[dict]:
|
) -> Dict[str, Optional[EventIdMembership]]:
|
||||||
"""Get user_id and membership of a set of event IDs."""
|
"""Get user_id and membership of a set of event IDs.
|
||||||
|
|
||||||
return await self.db_pool.simple_select_many_batch(
|
Returns:
|
||||||
|
Mapping from event ID to `EventIdMembership` if the event is a
|
||||||
|
membership event, otherwise the value is None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="room_memberships",
|
table="room_memberships",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=member_event_ids,
|
iterable=member_event_ids,
|
||||||
|
@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
desc="get_membership_from_event_ids",
|
desc="get_membership_from_event_ids",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
row["event_id"]: EventIdMembership(
|
||||||
|
membership=row["membership"], user_id=row["user_id"]
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
}
|
||||||
|
|
||||||
async def is_local_host_in_room_ignoring_users(
|
async def is_local_host_in_room_ignoring_users(
|
||||||
self, room_id: str, ignore_users: Collection[str]
|
self, room_id: str, ignore_users: Collection[str]
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
|
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -74,7 +74,7 @@ class SearchWorkerStore(SQLBaseStore):
|
||||||
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
|
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
|
||||||
)
|
)
|
||||||
|
|
||||||
args = (
|
args1 = (
|
||||||
(
|
(
|
||||||
entry.event_id,
|
entry.event_id,
|
||||||
entry.room_id,
|
entry.room_id,
|
||||||
|
@ -86,14 +86,14 @@ class SearchWorkerStore(SQLBaseStore):
|
||||||
for entry in entries
|
for entry in entries
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.execute_batch(sql, args)
|
txn.execute_batch(sql, args1)
|
||||||
|
|
||||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||||
sql = (
|
sql = (
|
||||||
"INSERT INTO event_search (event_id, room_id, key, value)"
|
"INSERT INTO event_search (event_id, room_id, key, value)"
|
||||||
" VALUES (?,?,?,?)"
|
" VALUES (?,?,?,?)"
|
||||||
)
|
)
|
||||||
args = (
|
args2 = (
|
||||||
(
|
(
|
||||||
entry.event_id,
|
entry.event_id,
|
||||||
entry.room_id,
|
entry.room_id,
|
||||||
|
@ -102,7 +102,7 @@ class SearchWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
for entry in entries
|
for entry in entries
|
||||||
)
|
)
|
||||||
txn.execute_batch(sql, args)
|
txn.execute_batch(sql, args2)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# This should be unreachable.
|
# This should be unreachable.
|
||||||
|
@ -427,7 +427,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||||
|
|
||||||
search_query = _parse_query(self.database_engine, search_term)
|
search_query = _parse_query(self.database_engine, search_term)
|
||||||
|
|
||||||
args = []
|
args: List[Any] = []
|
||||||
|
|
||||||
# Make sure we don't explode because the person is in too many rooms.
|
# Make sure we don't explode because the person is in too many rooms.
|
||||||
# We filter the results below regardless.
|
# We filter the results below regardless.
|
||||||
|
@ -496,7 +496,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||||
|
|
||||||
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
|
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
|
||||||
# search results (which is a data leak)
|
# search results (which is a data leak)
|
||||||
events = await self.get_events_as_list(
|
events = await self.get_events_as_list( # type: ignore[attr-defined]
|
||||||
[r["event_id"] for r in results],
|
[r["event_id"] for r in results],
|
||||||
redact_behaviour=EventRedactBehaviour.BLOCK,
|
redact_behaviour=EventRedactBehaviour.BLOCK,
|
||||||
)
|
)
|
||||||
|
@ -530,7 +530,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||||
room_ids: Collection[str],
|
room_ids: Collection[str],
|
||||||
search_term: str,
|
search_term: str,
|
||||||
keys: Iterable[str],
|
keys: Iterable[str],
|
||||||
limit,
|
limit: int,
|
||||||
pagination_token: Optional[str] = None,
|
pagination_token: Optional[str] = None,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
"""Performs a full text search over events with given keys.
|
"""Performs a full text search over events with given keys.
|
||||||
|
@ -549,7 +549,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||||
|
|
||||||
search_query = _parse_query(self.database_engine, search_term)
|
search_query = _parse_query(self.database_engine, search_term)
|
||||||
|
|
||||||
args = []
|
args: List[Any] = []
|
||||||
|
|
||||||
# Make sure we don't explode because the person is in too many rooms.
|
# Make sure we don't explode because the person is in too many rooms.
|
||||||
# We filter the results below regardless.
|
# We filter the results below regardless.
|
||||||
|
@ -573,9 +573,9 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||||
|
|
||||||
if pagination_token:
|
if pagination_token:
|
||||||
try:
|
try:
|
||||||
origin_server_ts, stream = pagination_token.split(",")
|
origin_server_ts_str, stream_str = pagination_token.split(",")
|
||||||
origin_server_ts = int(origin_server_ts)
|
origin_server_ts = int(origin_server_ts_str)
|
||||||
stream = int(stream)
|
stream = int(stream_str)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise SynapseError(400, "Invalid pagination token")
|
raise SynapseError(400, "Invalid pagination token")
|
||||||
|
|
||||||
|
@ -654,7 +654,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||||
|
|
||||||
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
|
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
|
||||||
# search results (which is a data leak)
|
# search results (which is a data leak)
|
||||||
events = await self.get_events_as_list(
|
events = await self.get_events_as_list( # type: ignore[attr-defined]
|
||||||
[r["event_id"] for r in results],
|
[r["event_id"] for r in results],
|
||||||
redact_behaviour=EventRedactBehaviour.BLOCK,
|
redact_behaviour=EventRedactBehaviour.BLOCK,
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import collections.abc
|
import collections.abc
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Iterable, Optional, Set
|
from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
|
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
|
||||||
|
@ -29,7 +29,7 @@ from synapse.storage.database import (
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import StateMap
|
from synapse.types import JsonDict, StateMap
|
||||||
from synapse.util.caches import intern_string
|
from synapse.util.caches import intern_string
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
|
||||||
|
@ -241,7 +241,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
# We delegate to the cached version
|
# We delegate to the cached version
|
||||||
return await self.get_current_state_ids(room_id)
|
return await self.get_current_state_ids(room_id)
|
||||||
|
|
||||||
def _get_filtered_current_state_ids_txn(txn):
|
def _get_filtered_current_state_ids_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> StateMap[str]:
|
||||||
results = {}
|
results = {}
|
||||||
sql = """
|
sql = """
|
||||||
SELECT type, state_key, event_id FROM current_state_events
|
SELECT type, state_key, event_id FROM current_state_events
|
||||||
|
@ -281,11 +283,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
|
|
||||||
event_id = state.get((EventTypes.CanonicalAlias, ""))
|
event_id = state.get((EventTypes.CanonicalAlias, ""))
|
||||||
if not event_id:
|
if not event_id:
|
||||||
return
|
return None
|
||||||
|
|
||||||
event = await self.get_event(event_id, allow_none=True)
|
event = await self.get_event(event_id, allow_none=True)
|
||||||
if not event:
|
if not event:
|
||||||
return
|
return None
|
||||||
|
|
||||||
return event.content.get("canonical_alias")
|
return event.content.get("canonical_alias")
|
||||||
|
|
||||||
|
@ -304,7 +306,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
list_name="event_ids",
|
list_name="event_ids",
|
||||||
num_args=1,
|
num_args=1,
|
||||||
)
|
)
|
||||||
async def _get_state_group_for_events(self, event_ids):
|
async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
|
||||||
"""Returns mapping event_id -> state_group"""
|
"""Returns mapping event_id -> state_group"""
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="event_to_state_groups",
|
table="event_to_state_groups",
|
||||||
|
@ -355,7 +357,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
||||||
):
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self.server_name = hs.hostname
|
self.server_name: str = hs.hostname
|
||||||
|
|
||||||
self.db_pool.updates.register_background_index_update(
|
self.db_pool.updates.register_background_index_update(
|
||||||
self.CURRENT_STATE_INDEX_UPDATE_NAME,
|
self.CURRENT_STATE_INDEX_UPDATE_NAME,
|
||||||
|
@ -375,7 +377,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
||||||
self._background_remove_left_rooms,
|
self._background_remove_left_rooms,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _background_remove_left_rooms(self, progress, batch_size):
|
async def _background_remove_left_rooms(
|
||||||
|
self, progress: JsonDict, batch_size: int
|
||||||
|
) -> int:
|
||||||
"""Background update to delete rows from `current_state_events` and
|
"""Background update to delete rows from `current_state_events` and
|
||||||
`event_forward_extremities` tables of rooms that the server is no
|
`event_forward_extremities` tables of rooms that the server is no
|
||||||
longer joined to.
|
longer joined to.
|
||||||
|
@ -383,7 +387,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
||||||
|
|
||||||
last_room_id = progress.get("last_room_id", "")
|
last_room_id = progress.get("last_room_id", "")
|
||||||
|
|
||||||
def _background_remove_left_rooms_txn(txn):
|
def _background_remove_left_rooms_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[bool, Set[str]]:
|
||||||
# get a batch of room ids to consider
|
# get a batch of room ids to consider
|
||||||
sql = """
|
sql = """
|
||||||
SELECT DISTINCT room_id FROM current_state_events
|
SELECT DISTINCT room_id FROM current_state_events
|
||||||
|
|
|
@ -108,7 +108,7 @@ class StatsStore(StateDeltasStore):
|
||||||
):
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self.server_name = hs.hostname
|
self.server_name: str = hs.hostname
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
self.stats_enabled = hs.config.stats.stats_enabled
|
self.stats_enabled = hs.config.stats.stats_enabled
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,8 @@ from typing import (
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -40,7 +42,12 @@ from synapse.storage.database import (
|
||||||
from synapse.storage.databases.main.state import StateFilter
|
from synapse.storage.databases.main.state import StateFilter
|
||||||
from synapse.storage.databases.main.state_deltas import StateDeltasStore
|
from synapse.storage.databases.main.state_deltas import StateDeltasStore
|
||||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||||
from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
|
from synapse.types import (
|
||||||
|
JsonDict,
|
||||||
|
UserProfile,
|
||||||
|
get_domain_from_id,
|
||||||
|
get_localpart_from_id,
|
||||||
|
)
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -61,7 +68,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self.server_name = hs.hostname
|
self.server_name: str = hs.hostname
|
||||||
|
|
||||||
self.db_pool.updates.register_background_update_handler(
|
self.db_pool.updates.register_background_update_handler(
|
||||||
"populate_user_directory_createtables",
|
"populate_user_directory_createtables",
|
||||||
|
@ -591,6 +598,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchResult(TypedDict):
|
||||||
|
limited: bool
|
||||||
|
results: List[UserProfile]
|
||||||
|
|
||||||
|
|
||||||
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||||
# How many records do we calculate before sending it to
|
# How many records do we calculate before sending it to
|
||||||
# add_users_who_share_private_rooms?
|
# add_users_who_share_private_rooms?
|
||||||
|
@ -718,7 +730,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||||
users.update(rows)
|
users.update(rows)
|
||||||
return list(users)
|
return list(users)
|
||||||
|
|
||||||
async def get_shared_rooms_for_users(
|
async def get_mutual_rooms_for_users(
|
||||||
self, user_id: str, other_user_id: str
|
self, user_id: str, other_user_id: str
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
"""
|
"""
|
||||||
|
@ -732,7 +744,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||||
A set of room ID's that the users share.
|
A set of room ID's that the users share.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_shared_rooms_for_users_txn(
|
def _get_mutual_rooms_for_users_txn(
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
|
@ -756,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||||
return rows
|
return rows
|
||||||
|
|
||||||
rows = await self.db_pool.runInteraction(
|
rows = await self.db_pool.runInteraction(
|
||||||
"get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
|
"get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
return {row["room_id"] for row in rows}
|
return {row["room_id"] for row in rows}
|
||||||
|
@ -777,7 +789,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||||
|
|
||||||
async def search_user_dir(
|
async def search_user_dir(
|
||||||
self, user_id: str, search_term: str, limit: int
|
self, user_id: str, search_term: str, limit: int
|
||||||
) -> JsonDict:
|
) -> SearchResult:
|
||||||
"""Searches for users in directory
|
"""Searches for users in directory
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -910,8 +922,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||||
# This should be unreachable.
|
# This should be unreachable.
|
||||||
raise Exception("Unrecognized database engine")
|
raise Exception("Unrecognized database engine")
|
||||||
|
|
||||||
results = await self.db_pool.execute(
|
results = cast(
|
||||||
|
List[UserProfile],
|
||||||
|
await self.db_pool.execute(
|
||||||
"search_user_dir", self.db_pool.cursor_to_dict, sql, *args
|
"search_user_dir", self.db_pool.cursor_to_dict, sql, *args
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
limited = len(results) > limit
|
limited = len(results) > limit
|
||||||
|
|
|
@ -27,7 +27,7 @@ def create_engine(database_config) -> BaseDatabaseEngine:
|
||||||
|
|
||||||
if name == "psycopg2":
|
if name == "psycopg2":
|
||||||
# Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
|
# Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
|
||||||
import psycopg2 # type: ignore
|
import psycopg2
|
||||||
|
|
||||||
return PostgresEngine(psycopg2, database_config)
|
return PostgresEngine(psycopg2, database_config)
|
||||||
|
|
||||||
|
|
|
@ -47,17 +47,26 @@ class PostgresEngine(BaseDatabaseEngine):
|
||||||
self.default_isolation_level = (
|
self.default_isolation_level = (
|
||||||
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
|
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
|
||||||
)
|
)
|
||||||
|
self.config = database_config
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def single_threaded(self) -> bool:
|
def single_threaded(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def get_db_locale(self, txn):
|
||||||
|
txn.execute(
|
||||||
|
"SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
|
||||||
|
)
|
||||||
|
collation, ctype = txn.fetchone()
|
||||||
|
return collation, ctype
|
||||||
|
|
||||||
def check_database(self, db_conn, allow_outdated_version: bool = False):
|
def check_database(self, db_conn, allow_outdated_version: bool = False):
|
||||||
# Get the version of PostgreSQL that we're using. As per the psycopg2
|
# Get the version of PostgreSQL that we're using. As per the psycopg2
|
||||||
# docs: The number is formed by converting the major, minor, and
|
# docs: The number is formed by converting the major, minor, and
|
||||||
# revision numbers into two-decimal-digit numbers and appending them
|
# revision numbers into two-decimal-digit numbers and appending them
|
||||||
# together. For example, version 8.1.5 will be returned as 80105
|
# together. For example, version 8.1.5 will be returned as 80105
|
||||||
self._version = db_conn.server_version
|
self._version = db_conn.server_version
|
||||||
|
allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
|
||||||
|
|
||||||
# Are we on a supported PostgreSQL version?
|
# Are we on a supported PostgreSQL version?
|
||||||
if not allow_outdated_version and self._version < 100000:
|
if not allow_outdated_version and self._version < 100000:
|
||||||
|
@ -72,21 +81,30 @@ class PostgresEngine(BaseDatabaseEngine):
|
||||||
"See docs/postgres.md for more information." % (rows[0][0],)
|
"See docs/postgres.md for more information." % (rows[0][0],)
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.execute(
|
collation, ctype = self.get_db_locale(txn)
|
||||||
"SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
|
|
||||||
)
|
|
||||||
collation, ctype = txn.fetchone()
|
|
||||||
if collation != "C":
|
if collation != "C":
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
"Database has incorrect collation of %r. Should be 'C'",
|
||||||
|
collation,
|
||||||
|
)
|
||||||
|
if not allow_unsafe_locale:
|
||||||
|
raise IncorrectDatabaseSetup(
|
||||||
"Database has incorrect collation of %r. Should be 'C'\n"
|
"Database has incorrect collation of %r. Should be 'C'\n"
|
||||||
"See docs/postgres.md for more information.",
|
"See docs/postgres.md for more information. You can override this check by"
|
||||||
|
"setting 'allow_unsafe_locale' to true in the database config.",
|
||||||
collation,
|
collation,
|
||||||
)
|
)
|
||||||
|
|
||||||
if ctype != "C":
|
if ctype != "C":
|
||||||
|
if not allow_unsafe_locale:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
"Database has incorrect ctype of %r. Should be 'C'",
|
||||||
|
ctype,
|
||||||
|
)
|
||||||
|
raise IncorrectDatabaseSetup(
|
||||||
"Database has incorrect ctype of %r. Should be 'C'\n"
|
"Database has incorrect ctype of %r. Should be 'C'\n"
|
||||||
"See docs/postgres.md for more information.",
|
"See docs/postgres.md for more information. You can override this check by"
|
||||||
|
"setting 'allow_unsafe_locale' to true in the database config.",
|
||||||
ctype,
|
ctype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -95,10 +113,7 @@ class PostgresEngine(BaseDatabaseEngine):
|
||||||
apply stricter checks on new databases versus existing database.
|
apply stricter checks on new databases versus existing database.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
txn.execute(
|
collation, ctype = self.get_db_locale(txn)
|
||||||
"SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
|
|
||||||
)
|
|
||||||
collation, ctype = txn.fetchone()
|
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
|
|
|
@ -1023,8 +1023,13 @@ class EventsPersistenceStorage:
|
||||||
|
|
||||||
# Check if any of the changes that we don't have events for are joins.
|
# Check if any of the changes that we don't have events for are joins.
|
||||||
if events_to_check:
|
if events_to_check:
|
||||||
rows = await self.main_store.get_membership_from_event_ids(events_to_check)
|
members = await self.main_store.get_membership_from_event_ids(
|
||||||
is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
|
events_to_check
|
||||||
|
)
|
||||||
|
is_still_joined = any(
|
||||||
|
member and member.membership == Membership.JOIN
|
||||||
|
for member in members.values()
|
||||||
|
)
|
||||||
if is_still_joined:
|
if is_still_joined:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -1060,9 +1065,11 @@ class EventsPersistenceStorage:
|
||||||
), event_id in current_state.items()
|
), event_id in current_state.items()
|
||||||
if typ == EventTypes.Member and not self.is_mine_id(state_key)
|
if typ == EventTypes.Member and not self.is_mine_id(state_key)
|
||||||
]
|
]
|
||||||
rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
|
members = await self.main_store.get_membership_from_event_ids(remote_event_ids)
|
||||||
potentially_left_users.update(
|
potentially_left_users.update(
|
||||||
row["user_id"] for row in rows if row["membership"] == Membership.JOIN
|
member.user_id
|
||||||
|
for member in members.values()
|
||||||
|
if member and member.membership == Membership.JOIN
|
||||||
)
|
)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -34,6 +34,7 @@ from typing import (
|
||||||
import attr
|
import attr
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
|
from typing_extensions import TypedDict
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
from zope.interface import Interface
|
from zope.interface import Interface
|
||||||
|
|
||||||
|
@ -63,6 +64,10 @@ MutableStateMap = MutableMapping[StateKey, T]
|
||||||
# JSON types. These could be made stronger, but will do for now.
|
# JSON types. These could be made stronger, but will do for now.
|
||||||
# A JSON-serialisable dict.
|
# A JSON-serialisable dict.
|
||||||
JsonDict = Dict[str, Any]
|
JsonDict = Dict[str, Any]
|
||||||
|
# A JSON-serialisable mapping; roughly speaking an immutable JSONDict.
|
||||||
|
# Useful when you have a TypedDict which isn't going to be mutated and you don't want
|
||||||
|
# to cast to JsonDict everywhere.
|
||||||
|
JsonMapping = Mapping[str, Any]
|
||||||
# A JSON-serialisable object.
|
# A JSON-serialisable object.
|
||||||
JsonSerializable = object
|
JsonSerializable = object
|
||||||
|
|
||||||
|
@ -791,3 +796,9 @@ class UserInfo:
|
||||||
is_deactivated: bool
|
is_deactivated: bool
|
||||||
is_guest: bool
|
is_guest: bool
|
||||||
is_shadow_banned: bool
|
is_shadow_banned: bool
|
||||||
|
|
||||||
|
|
||||||
|
class UserProfile(TypedDict):
|
||||||
|
user_id: str
|
||||||
|
display_name: Optional[str]
|
||||||
|
avatar_url: Optional[str]
|
||||||
|
|
|
@ -128,6 +128,19 @@ def _incorrect_version(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _no_reported_version(requirement: Requirement, extra: Optional[str] = None) -> str:
|
||||||
|
if extra:
|
||||||
|
return (
|
||||||
|
f"Synapse {VERSION} needs {requirement} for {extra}, "
|
||||||
|
f"but can't determine {requirement.name}'s version"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
f"Synapse {VERSION} needs {requirement}, "
|
||||||
|
f"but can't determine {requirement.name}'s version"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_requirements(extra: Optional[str] = None) -> None:
|
def check_requirements(extra: Optional[str] = None) -> None:
|
||||||
"""Check Synapse's dependencies are present and correctly versioned.
|
"""Check Synapse's dependencies are present and correctly versioned.
|
||||||
|
|
||||||
|
@ -163,8 +176,17 @@ def check_requirements(extra: Optional[str] = None) -> None:
|
||||||
deps_unfulfilled.append(requirement.name)
|
deps_unfulfilled.append(requirement.name)
|
||||||
errors.append(_not_installed(requirement, extra))
|
errors.append(_not_installed(requirement, extra))
|
||||||
else:
|
else:
|
||||||
|
if dist.version is None:
|
||||||
|
# This shouldn't happen---it suggests a borked virtualenv. (See #12223)
|
||||||
|
# Try to give a vaguely helpful error message anyway.
|
||||||
|
# Type-ignore: the annotations don't reflect reality: see
|
||||||
|
# https://github.com/python/typeshed/issues/7513
|
||||||
|
# https://bugs.python.org/issue47060
|
||||||
|
deps_unfulfilled.append(requirement.name) # type: ignore[unreachable]
|
||||||
|
errors.append(_no_reported_version(requirement, extra))
|
||||||
|
|
||||||
# We specify prereleases=True to allow prereleases such as RCs.
|
# We specify prereleases=True to allow prereleases such as RCs.
|
||||||
if not requirement.specifier.contains(dist.version, prereleases=True):
|
elif not requirement.specifier.contains(dist.version, prereleases=True):
|
||||||
deps_unfulfilled.append(requirement.name)
|
deps_unfulfilled.append(requirement.name)
|
||||||
errors.append(_incorrect_version(requirement, dist.version, extra))
|
errors.append(_incorrect_version(requirement, dist.version, extra))
|
||||||
|
|
||||||
|
|
|
@ -14,12 +14,7 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, FrozenSet, List, Optional
|
from typing import Dict, FrozenSet, List, Optional
|
||||||
|
|
||||||
from synapse.api.constants import (
|
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
|
||||||
AccountDataTypes,
|
|
||||||
EventTypes,
|
|
||||||
HistoryVisibility,
|
|
||||||
Membership,
|
|
||||||
)
|
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.utils import prune_event
|
from synapse.events.utils import prune_event
|
||||||
from synapse.storage import Storage
|
from synapse.storage import Storage
|
||||||
|
@ -87,15 +82,8 @@ async def filter_events_for_client(
|
||||||
state_filter=StateFilter.from_types(types),
|
state_filter=StateFilter.from_types(types),
|
||||||
)
|
)
|
||||||
|
|
||||||
ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
|
# Get the users who are ignored by the requesting user.
|
||||||
user_id, AccountDataTypes.IGNORED_USER_LIST
|
ignore_list = await storage.main.ignored_users(user_id)
|
||||||
)
|
|
||||||
|
|
||||||
ignore_list: FrozenSet[str] = frozenset()
|
|
||||||
if ignore_dict_content:
|
|
||||||
ignored_users_dict = ignore_dict_content.get("ignored_users", {})
|
|
||||||
if isinstance(ignored_users_dict, dict):
|
|
||||||
ignore_list = frozenset(ignored_users_dict.keys())
|
|
||||||
|
|
||||||
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
|
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
|
||||||
|
|
||||||
|
|
|
@ -11,14 +11,16 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import synapse.app.homeserver
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
|
||||||
from tests.unittest import TestCase
|
from tests.config.utils import ConfigFileTestCase
|
||||||
from tests.utils import default_config
|
from tests.utils import default_config
|
||||||
|
|
||||||
|
|
||||||
class RegistrationConfigTestCase(TestCase):
|
class RegistrationConfigTestCase(ConfigFileTestCase):
|
||||||
def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self):
|
def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self):
|
||||||
"""
|
"""
|
||||||
session_lifetime should logically be larger than, or at least as large as,
|
session_lifetime should logically be larger than, or at least as large as,
|
||||||
|
@ -76,3 +78,19 @@ class RegistrationConfigTestCase(TestCase):
|
||||||
HomeServerConfig().parse_config_dict(
|
HomeServerConfig().parse_config_dict(
|
||||||
{"session_lifetime": "31m", "refresh_token_lifetime": "31m", **config_dict}
|
{"session_lifetime": "31m", "refresh_token_lifetime": "31m", **config_dict}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_refuse_to_start_if_open_registration_and_no_verification(self):
|
||||||
|
self.generate_config()
|
||||||
|
self.add_lines_to_config(
|
||||||
|
[
|
||||||
|
" ",
|
||||||
|
"enable_registration: true",
|
||||||
|
"registrations_require_3pid: []",
|
||||||
|
"enable_registration_captcha: false",
|
||||||
|
"registration_requires_token: false",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that allowing open registration without verification raises an error
|
||||||
|
with self.assertRaises(ConfigError):
|
||||||
|
synapse.app.homeserver.setup(["-c", self.config_file])
|
||||||
|
|
|
@ -11,9 +11,14 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Any, Dict
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.handlers.cas import CasResponse
|
from synapse.handlers.cas import CasResponse
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.test_utils import simple_async_mock
|
from tests.test_utils import simple_async_mock
|
||||||
from tests.unittest import HomeserverTestCase, override_config
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
@ -24,7 +29,7 @@ SERVER_URL = "https://issuer/"
|
||||||
|
|
||||||
|
|
||||||
class CasHandlerTestCase(HomeserverTestCase):
|
class CasHandlerTestCase(HomeserverTestCase):
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["public_baseurl"] = BASE_URL
|
config["public_baseurl"] = BASE_URL
|
||||||
cas_config = {
|
cas_config = {
|
||||||
|
@ -40,7 +45,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
hs = self.setup_test_homeserver()
|
hs = self.setup_test_homeserver()
|
||||||
|
|
||||||
self.handler = hs.get_cas_handler()
|
self.handler = hs.get_cas_handler()
|
||||||
|
@ -51,7 +56,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def test_map_cas_user_to_user(self):
|
def test_map_cas_user_to_user(self) -> None:
|
||||||
"""Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
|
"""Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
|
@ -75,7 +80,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
auth_provider_session_id=None,
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_map_cas_user_to_existing_user(self):
|
def test_map_cas_user_to_existing_user(self) -> None:
|
||||||
"""Existing users can log in with CAS account."""
|
"""Existing users can log in with CAS account."""
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -119,7 +124,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
auth_provider_session_id=None,
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_map_cas_user_to_invalid_localpart(self):
|
def test_map_cas_user_to_invalid_localpart(self) -> None:
|
||||||
"""CAS automaps invalid characters to base-64 encoding."""
|
"""CAS automaps invalid characters to base-64 encoding."""
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
|
@ -150,7 +155,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_required_attributes(self):
|
def test_required_attributes(self) -> None:
|
||||||
"""The required attributes must be met from the CAS response."""
|
"""The required attributes must be met from the CAS response."""
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
|
@ -166,7 +171,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
auth_handler.complete_sso_login.assert_not_called()
|
auth_handler.complete_sso_login.assert_not_called()
|
||||||
|
|
||||||
# The response doesn't have any department.
|
# The response doesn't have any department.
|
||||||
cas_response = CasResponse("test_user", {"userGroup": "staff"})
|
cas_response = CasResponse("test_user", {"userGroup": ["staff"]})
|
||||||
request.reset_mock()
|
request.reset_mock()
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
|
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
|
||||||
|
|
|
@ -12,14 +12,18 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Any, Awaitable, Callable, Dict
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.api.errors
|
import synapse.api.errors
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.rest.client import directory, login, room
|
from synapse.rest.client import directory, login, room
|
||||||
from synapse.types import RoomAlias, create_requester
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict, RoomAlias, create_requester
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import make_awaitable
|
from tests.test_utils import make_awaitable
|
||||||
|
@ -28,13 +32,15 @@ from tests.test_utils import make_awaitable
|
||||||
class DirectoryTestCase(unittest.HomeserverTestCase):
|
class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
"""Tests the directory service."""
|
"""Tests the directory service."""
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.mock_federation = Mock()
|
self.mock_federation = Mock()
|
||||||
self.mock_registry = Mock()
|
self.mock_registry = Mock()
|
||||||
|
|
||||||
self.query_handlers = {}
|
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
|
||||||
|
|
||||||
def register_query_handler(query_type, handler):
|
def register_query_handler(
|
||||||
|
query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
|
||||||
|
) -> None:
|
||||||
self.query_handlers[query_type] = handler
|
self.query_handlers[query_type] = handler
|
||||||
|
|
||||||
self.mock_registry.register_query_handler = register_query_handler
|
self.mock_registry.register_query_handler = register_query_handler
|
||||||
|
@ -54,7 +60,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def test_get_local_association(self):
|
def test_get_local_association(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.create_room_alias_association(
|
self.store.create_room_alias_association(
|
||||||
self.my_room, "!8765qwer:test", ["test"]
|
self.my_room, "!8765qwer:test", ["test"]
|
||||||
|
@ -65,7 +71,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
|
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
|
||||||
|
|
||||||
def test_get_remote_association(self):
|
def test_get_remote_association(self) -> None:
|
||||||
self.mock_federation.make_query.return_value = make_awaitable(
|
self.mock_federation.make_query.return_value = make_awaitable(
|
||||||
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
|
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
|
||||||
)
|
)
|
||||||
|
@ -83,7 +89,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
ignore_backoff=True,
|
ignore_backoff=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_incoming_fed_query(self):
|
def test_incoming_fed_query(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.create_room_alias_association(
|
self.store.create_room_alias_association(
|
||||||
self.your_room, "!8765asdf:test", ["test"]
|
self.your_room, "!8765asdf:test", ["test"]
|
||||||
|
@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
|
||||||
directory.register_servlets,
|
directory.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.handler = hs.get_directory_handler()
|
self.handler = hs.get_directory_handler()
|
||||||
|
|
||||||
# Create user
|
# Create user
|
||||||
|
@ -125,7 +131,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
|
||||||
self.test_user_tok = self.login("user", "pass")
|
self.test_user_tok = self.login("user", "pass")
|
||||||
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
|
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
|
||||||
|
|
||||||
def test_create_alias_joined_room(self):
|
def test_create_alias_joined_room(self) -> None:
|
||||||
"""A user can create an alias for a room they're in."""
|
"""A user can create an alias for a room they're in."""
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.handler.create_association(
|
self.handler.create_association(
|
||||||
|
@ -135,7 +141,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_create_alias_other_room(self):
|
def test_create_alias_other_room(self) -> None:
|
||||||
"""A user cannot create an alias for a room they're NOT in."""
|
"""A user cannot create an alias for a room they're NOT in."""
|
||||||
other_room_id = self.helper.create_room_as(
|
other_room_id = self.helper.create_room_as(
|
||||||
self.admin_user, tok=self.admin_user_tok
|
self.admin_user, tok=self.admin_user_tok
|
||||||
|
@ -150,7 +156,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
|
||||||
synapse.api.errors.SynapseError,
|
synapse.api.errors.SynapseError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_create_alias_admin(self):
|
def test_create_alias_admin(self) -> None:
|
||||||
"""An admin can create an alias for a room they're NOT in."""
|
"""An admin can create an alias for a room they're NOT in."""
|
||||||
other_room_id = self.helper.create_room_as(
|
other_room_id = self.helper.create_room_as(
|
||||||
self.test_user, tok=self.test_user_tok
|
self.test_user, tok=self.test_user_tok
|
||||||
|
@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||||
directory.register_servlets,
|
directory.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.handler = hs.get_directory_handler()
|
self.handler = hs.get_directory_handler()
|
||||||
self.state_handler = hs.get_state_handler()
|
self.state_handler = hs.get_state_handler()
|
||||||
|
@ -195,7 +201,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||||
self.test_user_tok = self.login("user", "pass")
|
self.test_user_tok = self.login("user", "pass")
|
||||||
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
|
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
|
||||||
|
|
||||||
def _create_alias(self, user):
|
def _create_alias(self, user) -> None:
|
||||||
# Create a new alias to this room.
|
# Create a new alias to this room.
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.create_room_alias_association(
|
self.store.create_room_alias_association(
|
||||||
|
@ -203,7 +209,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_delete_alias_not_allowed(self):
|
def test_delete_alias_not_allowed(self) -> None:
|
||||||
"""A user that doesn't meet the expected guidelines cannot delete an alias."""
|
"""A user that doesn't meet the expected guidelines cannot delete an alias."""
|
||||||
self._create_alias(self.admin_user)
|
self._create_alias(self.admin_user)
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
|
@ -213,7 +219,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||||
synapse.api.errors.AuthError,
|
synapse.api.errors.AuthError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_delete_alias_creator(self):
|
def test_delete_alias_creator(self) -> None:
|
||||||
"""An alias creator can delete their own alias."""
|
"""An alias creator can delete their own alias."""
|
||||||
# Create an alias from a different user.
|
# Create an alias from a different user.
|
||||||
self._create_alias(self.test_user)
|
self._create_alias(self.test_user)
|
||||||
|
@ -232,7 +238,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||||
synapse.api.errors.SynapseError,
|
synapse.api.errors.SynapseError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_delete_alias_admin(self):
|
def test_delete_alias_admin(self) -> None:
|
||||||
"""A server admin can delete an alias created by another user."""
|
"""A server admin can delete an alias created by another user."""
|
||||||
# Create an alias from a different user.
|
# Create an alias from a different user.
|
||||||
self._create_alias(self.test_user)
|
self._create_alias(self.test_user)
|
||||||
|
@ -251,7 +257,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||||
synapse.api.errors.SynapseError,
|
synapse.api.errors.SynapseError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_delete_alias_sufficient_power(self):
|
def test_delete_alias_sufficient_power(self) -> None:
|
||||||
"""A user with a sufficient power level should be able to delete an alias."""
|
"""A user with a sufficient power level should be able to delete an alias."""
|
||||||
self._create_alias(self.admin_user)
|
self._create_alias(self.admin_user)
|
||||||
|
|
||||||
|
@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||||
directory.register_servlets,
|
directory.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.handler = hs.get_directory_handler()
|
self.handler = hs.get_directory_handler()
|
||||||
self.state_handler = hs.get_state_handler()
|
self.state_handler = hs.get_state_handler()
|
||||||
|
@ -317,7 +323,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
return room_alias
|
return room_alias
|
||||||
|
|
||||||
def _set_canonical_alias(self, content):
|
def _set_canonical_alias(self, content) -> None:
|
||||||
"""Configure the canonical alias state on the room."""
|
"""Configure the canonical alias state on the room."""
|
||||||
self.helper.send_state(
|
self.helper.send_state(
|
||||||
self.room_id,
|
self.room_id,
|
||||||
|
@ -334,7 +340,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_remove_alias(self):
|
def test_remove_alias(self) -> None:
|
||||||
"""Removing an alias that is the canonical alias should remove it there too."""
|
"""Removing an alias that is the canonical alias should remove it there too."""
|
||||||
# Set this new alias as the canonical alias for this room
|
# Set this new alias as the canonical alias for this room
|
||||||
self._set_canonical_alias(
|
self._set_canonical_alias(
|
||||||
|
@ -356,7 +362,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertNotIn("alias", data["content"])
|
self.assertNotIn("alias", data["content"])
|
||||||
self.assertNotIn("alt_aliases", data["content"])
|
self.assertNotIn("alt_aliases", data["content"])
|
||||||
|
|
||||||
def test_remove_other_alias(self):
|
def test_remove_other_alias(self) -> None:
|
||||||
"""Removing an alias listed as in alt_aliases should remove it there too."""
|
"""Removing an alias listed as in alt_aliases should remove it there too."""
|
||||||
# Create a second alias.
|
# Create a second alias.
|
||||||
other_test_alias = "#test2:test"
|
other_test_alias = "#test2:test"
|
||||||
|
@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
servlets = [directory.register_servlets, room.register_servlets]
|
servlets = [directory.register_servlets, room.register_servlets]
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
|
|
||||||
# Add custom alias creation rules to the config.
|
# Add custom alias creation rules to the config.
|
||||||
|
@ -403,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def test_denied(self):
|
def test_denied(self) -> None:
|
||||||
room_id = self.helper.create_room_as(self.user_id)
|
room_id = self.helper.create_room_as(self.user_id)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
|
@ -413,7 +419,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(403, channel.code, channel.result)
|
self.assertEqual(403, channel.code, channel.result)
|
||||||
|
|
||||||
def test_allowed(self):
|
def test_allowed(self) -> None:
|
||||||
room_id = self.helper.create_room_as(self.user_id)
|
room_id = self.helper.create_room_as(self.user_id)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
|
@ -423,7 +429,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(200, channel.code, channel.result)
|
self.assertEqual(200, channel.code, channel.result)
|
||||||
|
|
||||||
def test_denied_during_creation(self):
|
def test_denied_during_creation(self) -> None:
|
||||||
"""A room alias that is not allowed should be rejected during creation."""
|
"""A room alias that is not allowed should be rejected during creation."""
|
||||||
# Invalid room alias.
|
# Invalid room alias.
|
||||||
self.helper.create_room_as(
|
self.helper.create_room_as(
|
||||||
|
@ -432,7 +438,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||||
extra_content={"room_alias_name": "foo"},
|
extra_content={"room_alias_name": "foo"},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_allowed_during_creation(self):
|
def test_allowed_during_creation(self) -> None:
|
||||||
"""A valid room alias should be allowed during creation."""
|
"""A valid room alias should be allowed during creation."""
|
||||||
room_id = self.helper.create_room_as(
|
room_id = self.helper.create_room_as(
|
||||||
self.user_id,
|
self.user_id,
|
||||||
|
@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
data = {"room_alias_name": "unofficial_test"}
|
data = {"room_alias_name": "unofficial_test"}
|
||||||
allowed_localpart = "allowed"
|
allowed_localpart = "allowed"
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
|
|
||||||
# Add custom room list publication rules to the config.
|
# Add custom room list publication rules to the config.
|
||||||
|
@ -474,7 +480,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(
|
||||||
|
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
|
||||||
|
) -> HomeServer:
|
||||||
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
|
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
|
||||||
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
|
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
|
||||||
|
|
||||||
|
@ -483,7 +491,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def test_denied_without_publication_permission(self):
|
def test_denied_without_publication_permission(self) -> None:
|
||||||
"""
|
"""
|
||||||
Try to create a room, register an alias for it, and publish it,
|
Try to create a room, register an alias for it, and publish it,
|
||||||
as a user without permission to publish rooms.
|
as a user without permission to publish rooms.
|
||||||
|
@ -497,7 +505,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
expect_code=403,
|
expect_code=403,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_allowed_when_creating_private_room(self):
|
def test_allowed_when_creating_private_room(self) -> None:
|
||||||
"""
|
"""
|
||||||
Try to create a room, register an alias for it, and NOT publish it,
|
Try to create a room, register an alias for it, and NOT publish it,
|
||||||
as a user without permission to publish rooms.
|
as a user without permission to publish rooms.
|
||||||
|
@ -511,7 +519,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
expect_code=200,
|
expect_code=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_allowed_with_publication_permission(self):
|
def test_allowed_with_publication_permission(self) -> None:
|
||||||
"""
|
"""
|
||||||
Try to create a room, register an alias for it, and publish it,
|
Try to create a room, register an alias for it, and publish it,
|
||||||
as a user WITH permission to publish rooms.
|
as a user WITH permission to publish rooms.
|
||||||
|
@ -525,7 +533,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
expect_code=200,
|
expect_code=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_denied_publication_with_invalid_alias(self):
|
def test_denied_publication_with_invalid_alias(self) -> None:
|
||||||
"""
|
"""
|
||||||
Try to create a room, register an alias for it, and publish it,
|
Try to create a room, register an alias for it, and publish it,
|
||||||
as a user WITH permission to publish rooms.
|
as a user WITH permission to publish rooms.
|
||||||
|
@ -538,7 +546,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
expect_code=403,
|
expect_code=403,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_can_create_as_private_room_after_rejection(self):
|
def test_can_create_as_private_room_after_rejection(self) -> None:
|
||||||
"""
|
"""
|
||||||
After failing to publish a room with an alias as a user without publish permission,
|
After failing to publish a room with an alias as a user without publish permission,
|
||||||
retry as the same user, but without publishing the room.
|
retry as the same user, but without publishing the room.
|
||||||
|
@ -549,7 +557,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
self.test_denied_without_publication_permission()
|
self.test_denied_without_publication_permission()
|
||||||
self.test_allowed_when_creating_private_room()
|
self.test_allowed_when_creating_private_room()
|
||||||
|
|
||||||
def test_can_create_with_permission_after_rejection(self):
|
def test_can_create_with_permission_after_rejection(self) -> None:
|
||||||
"""
|
"""
|
||||||
After failing to publish a room with an alias as a user without publish permission,
|
After failing to publish a room with an alias as a user without publish permission,
|
||||||
retry as someone with permission, using the same alias.
|
retry as someone with permission, using the same alias.
|
||||||
|
@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
servlets = [directory.register_servlets, room.register_servlets]
|
servlets = [directory.register_servlets, room.register_servlets]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(
|
||||||
|
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
|
||||||
|
) -> HomeServer:
|
||||||
room_id = self.helper.create_room_as(self.user_id)
|
room_id = self.helper.create_room_as(self.user_id)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
|
@ -579,7 +589,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def test_disabling_room_list(self):
|
def test_disabling_room_list(self) -> None:
|
||||||
self.room_list_handler.enable_room_list_search = True
|
self.room_list_handler.enable_room_list_search = True
|
||||||
self.directory_handler.enable_room_list_search = True
|
self.directory_handler.enable_room_list_search = True
|
||||||
|
|
||||||
|
|
|
@ -20,33 +20,37 @@ from parameterized import parameterized
|
||||||
from signedjson import key as key, sign as sign
|
from signedjson import key as key, sign as sign
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import RoomEncryptionAlgorithms
|
from synapse.api.constants import RoomEncryptionAlgorithms
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import make_awaitable
|
from tests.test_utils import make_awaitable
|
||||||
|
|
||||||
|
|
||||||
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
return self.setup_test_homeserver(federation_client=mock.Mock())
|
return self.setup_test_homeserver(federation_client=mock.Mock())
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.handler = hs.get_e2e_keys_handler()
|
self.handler = hs.get_e2e_keys_handler()
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
|
|
||||||
def test_query_local_devices_no_devices(self):
|
def test_query_local_devices_no_devices(self) -> None:
|
||||||
"""If the user has no devices, we expect an empty list."""
|
"""If the user has no devices, we expect an empty list."""
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
res = self.get_success(self.handler.query_local_devices({local_user: None}))
|
res = self.get_success(self.handler.query_local_devices({local_user: None}))
|
||||||
self.assertDictEqual(res, {local_user: {}})
|
self.assertDictEqual(res, {local_user: {}})
|
||||||
|
|
||||||
def test_reupload_one_time_keys(self):
|
def test_reupload_one_time_keys(self) -> None:
|
||||||
"""we should be able to re-upload the same keys"""
|
"""we should be able to re-upload the same keys"""
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
device_id = "xyz"
|
device_id = "xyz"
|
||||||
keys = {
|
keys: JsonDict = {
|
||||||
"alg1:k1": "key1",
|
"alg1:k1": "key1",
|
||||||
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
|
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
|
||||||
"alg2:k3": {"key": "key3"},
|
"alg2:k3": {"key": "key3"},
|
||||||
|
@ -74,7 +78,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
|
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_change_one_time_keys(self):
|
def test_change_one_time_keys(self) -> None:
|
||||||
"""attempts to change one-time-keys should be rejected"""
|
"""attempts to change one-time-keys should be rejected"""
|
||||||
|
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
|
@ -134,7 +138,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_claim_one_time_key(self):
|
def test_claim_one_time_key(self) -> None:
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
device_id = "xyz"
|
device_id = "xyz"
|
||||||
keys = {"alg1:k1": "key1"}
|
keys = {"alg1:k1": "key1"}
|
||||||
|
@ -161,7 +165,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_fallback_key(self):
|
def test_fallback_key(self) -> None:
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
device_id = "xyz"
|
device_id = "xyz"
|
||||||
fallback_key = {"alg1:k1": "fallback_key1"}
|
fallback_key = {"alg1:k1": "fallback_key1"}
|
||||||
|
@ -294,7 +298,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
|
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_replace_master_key(self):
|
def test_replace_master_key(self) -> None:
|
||||||
"""uploading a new signing key should make the old signing key unavailable"""
|
"""uploading a new signing key should make the old signing key unavailable"""
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
keys1 = {
|
keys1 = {
|
||||||
|
@ -328,7 +332,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
|
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
|
||||||
|
|
||||||
def test_reupload_signatures(self):
|
def test_reupload_signatures(self) -> None:
|
||||||
"""re-uploading a signature should not fail"""
|
"""re-uploading a signature should not fail"""
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
keys1 = {
|
keys1 = {
|
||||||
|
@ -433,7 +437,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
|
self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
|
||||||
self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
|
self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
|
||||||
|
|
||||||
def test_self_signing_key_doesnt_show_up_as_device(self):
|
def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
|
||||||
"""signing keys should be hidden when fetching a user's devices"""
|
"""signing keys should be hidden when fetching a user's devices"""
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
keys1 = {
|
keys1 = {
|
||||||
|
@ -462,7 +466,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
res = self.get_success(self.handler.query_local_devices({local_user: None}))
|
res = self.get_success(self.handler.query_local_devices({local_user: None}))
|
||||||
self.assertDictEqual(res, {local_user: {}})
|
self.assertDictEqual(res, {local_user: {}})
|
||||||
|
|
||||||
def test_upload_signatures(self):
|
def test_upload_signatures(self) -> None:
|
||||||
"""should check signatures that are uploaded"""
|
"""should check signatures that are uploaded"""
|
||||||
# set up a user with cross-signing keys and a device. This user will
|
# set up a user with cross-signing keys and a device. This user will
|
||||||
# try uploading signatures
|
# try uploading signatures
|
||||||
|
@ -686,7 +690,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
|
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_query_devices_remote_no_sync(self):
|
def test_query_devices_remote_no_sync(self) -> None:
|
||||||
"""Tests that querying keys for a remote user that we don't share a room
|
"""Tests that querying keys for a remote user that we don't share a room
|
||||||
with returns the cross signing keys correctly.
|
with returns the cross signing keys correctly.
|
||||||
"""
|
"""
|
||||||
|
@ -759,7 +763,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_query_devices_remote_sync(self):
|
def test_query_devices_remote_sync(self) -> None:
|
||||||
"""Tests that querying keys for a remote user that we share a room with,
|
"""Tests that querying keys for a remote user that we share a room with,
|
||||||
but haven't yet fetched the keys for, returns the cross signing keys
|
but haven't yet fetched the keys for, returns the cross signing keys
|
||||||
correctly.
|
correctly.
|
||||||
|
@ -845,7 +849,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
(["device_1", "device_2"],),
|
(["device_1", "device_2"],),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
|
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None:
|
||||||
"""Test that requests for all of a remote user's devices are cached.
|
"""Test that requests for all of a remote user's devices are cached.
|
||||||
|
|
||||||
We do this by asserting that only one call over federation was made, and that
|
We do this by asserting that only one call over federation was made, and that
|
||||||
|
@ -853,7 +857,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
local_user_id = "@test:test"
|
local_user_id = "@test:test"
|
||||||
remote_user_id = "@test:other"
|
remote_user_id = "@test:other"
|
||||||
request_body = {"device_keys": {remote_user_id: []}}
|
request_body: JsonDict = {"device_keys": {remote_user_id: []}}
|
||||||
|
|
||||||
response_devices = [
|
response_devices = [
|
||||||
{
|
{
|
||||||
|
|
|
@ -12,9 +12,11 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List, cast
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
|
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersions
|
||||||
|
@ -23,7 +25,9 @@ from synapse.federation.federation_base import event_from_pdu_json
|
||||||
from synapse.logging.context import LoggingContext, run_in_background
|
from synapse.logging.context import LoggingContext, run_in_background
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.types import create_requester
|
from synapse.types import create_requester
|
||||||
|
from synapse.util import Clock
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -42,7 +46,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
hs = self.setup_test_homeserver(federation_http_client=None)
|
hs = self.setup_test_homeserver(federation_http_client=None)
|
||||||
self.handler = hs.get_federation_handler()
|
self.handler = hs.get_federation_handler()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
@ -50,7 +54,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||||
self._event_auth_handler = hs.get_event_auth_handler()
|
self._event_auth_handler = hs.get_event_auth_handler()
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def test_exchange_revoked_invite(self):
|
def test_exchange_revoked_invite(self) -> None:
|
||||||
user_id = self.register_user("kermit", "test")
|
user_id = self.register_user("kermit", "test")
|
||||||
tok = self.login("kermit", "test")
|
tok = self.login("kermit", "test")
|
||||||
|
|
||||||
|
@ -96,7 +100,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
|
self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
|
||||||
self.assertEqual(failure.msg, "You are not invited to this room.")
|
self.assertEqual(failure.msg, "You are not invited to this room.")
|
||||||
|
|
||||||
def test_rejected_message_event_state(self):
|
def test_rejected_message_event_state(self) -> None:
|
||||||
"""
|
"""
|
||||||
Check that we store the state group correctly for rejected non-state events.
|
Check that we store the state group correctly for rejected non-state events.
|
||||||
|
|
||||||
|
@ -126,7 +130,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||||
"content": {},
|
"content": {},
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
"sender": "@yetanotheruser:" + OTHER_SERVER,
|
"sender": "@yetanotheruser:" + OTHER_SERVER,
|
||||||
"depth": join_event["depth"] + 1,
|
"depth": cast(int, join_event["depth"]) + 1,
|
||||||
"prev_events": [join_event.event_id],
|
"prev_events": [join_event.event_id],
|
||||||
"auth_events": [],
|
"auth_events": [],
|
||||||
"origin_server_ts": self.clock.time_msec(),
|
"origin_server_ts": self.clock.time_msec(),
|
||||||
|
@ -149,7 +153,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(sg, sg2)
|
self.assertEqual(sg, sg2)
|
||||||
|
|
||||||
def test_rejected_state_event_state(self):
|
def test_rejected_state_event_state(self) -> None:
|
||||||
"""
|
"""
|
||||||
Check that we store the state group correctly for rejected state events.
|
Check that we store the state group correctly for rejected state events.
|
||||||
|
|
||||||
|
@ -180,7 +184,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||||
"content": {},
|
"content": {},
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
"sender": "@yetanotheruser:" + OTHER_SERVER,
|
"sender": "@yetanotheruser:" + OTHER_SERVER,
|
||||||
"depth": join_event["depth"] + 1,
|
"depth": cast(int, join_event["depth"]) + 1,
|
||||||
"prev_events": [join_event.event_id],
|
"prev_events": [join_event.event_id],
|
||||||
"auth_events": [],
|
"auth_events": [],
|
||||||
"origin_server_ts": self.clock.time_msec(),
|
"origin_server_ts": self.clock.time_msec(),
|
||||||
|
@ -203,7 +207,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(sg, sg2)
|
self.assertEqual(sg, sg2)
|
||||||
|
|
||||||
def test_backfill_with_many_backward_extremities(self):
|
def test_backfill_with_many_backward_extremities(self) -> None:
|
||||||
"""
|
"""
|
||||||
Check that we can backfill with many backward extremities.
|
Check that we can backfill with many backward extremities.
|
||||||
The goal is to make sure that when we only use a portion
|
The goal is to make sure that when we only use a portion
|
||||||
|
@ -262,7 +266,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.get_success(d)
|
self.get_success(d)
|
||||||
|
|
||||||
def test_backfill_floating_outlier_membership_auth(self):
|
def test_backfill_floating_outlier_membership_auth(self) -> None:
|
||||||
"""
|
"""
|
||||||
As the local homeserver, check that we can properly process a federated
|
As the local homeserver, check that we can properly process a federated
|
||||||
event from the OTHER_SERVER with auth_events that include a floating
|
event from the OTHER_SERVER with auth_events that include a floating
|
||||||
|
@ -377,7 +381,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||||
for ae in auth_events
|
for ae in auth_events
|
||||||
]
|
]
|
||||||
|
|
||||||
self.handler.federation_client.get_event_auth = get_event_auth
|
self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment]
|
||||||
|
|
||||||
with LoggingContext("receive_pdu"):
|
with LoggingContext("receive_pdu"):
|
||||||
# Fake the OTHER_SERVER federating the message event over to our local homeserver
|
# Fake the OTHER_SERVER federating the message event over to our local homeserver
|
||||||
|
@ -397,7 +401,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||||
@unittest.override_config(
|
@unittest.override_config(
|
||||||
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
|
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
|
||||||
)
|
)
|
||||||
def test_invite_by_user_ratelimit(self):
|
def test_invite_by_user_ratelimit(self) -> None:
|
||||||
"""Tests that invites from federation to a particular user are
|
"""Tests that invites from federation to a particular user are
|
||||||
actually rate-limited.
|
actually rate-limited.
|
||||||
"""
|
"""
|
||||||
|
@ -446,7 +450,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||||
exc=LimitExceededError,
|
exc=LimitExceededError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_and_send_join_event(self, other_server, other_user, room_id):
|
def _build_and_send_join_event(
|
||||||
|
self, other_server: str, other_user: str, room_id: str
|
||||||
|
) -> EventBase:
|
||||||
join_event = self.get_success(
|
join_event = self.get_success(
|
||||||
self.handler.on_make_join_request(other_server, room_id, other_user)
|
self.handler.on_make_join_request(other_server, room_id, other_user)
|
||||||
)
|
)
|
||||||
|
@ -469,7 +475,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
|
|
||||||
class EventFromPduTestCase(TestCase):
|
class EventFromPduTestCase(TestCase):
|
||||||
def test_valid_json(self):
|
def test_valid_json(self) -> None:
|
||||||
"""Valid JSON should be turned into an event."""
|
"""Valid JSON should be turned into an event."""
|
||||||
ev = event_from_pdu_json(
|
ev = event_from_pdu_json(
|
||||||
{
|
{
|
||||||
|
@ -487,7 +493,7 @@ class EventFromPduTestCase(TestCase):
|
||||||
|
|
||||||
self.assertIsInstance(ev, EventBase)
|
self.assertIsInstance(ev, EventBase)
|
||||||
|
|
||||||
def test_invalid_numbers(self):
|
def test_invalid_numbers(self) -> None:
|
||||||
"""Invalid values for an integer should be rejected, all floats should be rejected."""
|
"""Invalid values for an integer should be rejected, all floats should be rejected."""
|
||||||
for value in [
|
for value in [
|
||||||
-(2 ** 53),
|
-(2 ** 53),
|
||||||
|
@ -512,7 +518,7 @@ class EventFromPduTestCase(TestCase):
|
||||||
RoomVersions.V6,
|
RoomVersions.V6,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_invalid_nested(self):
|
def test_invalid_nested(self) -> None:
|
||||||
"""List and dictionaries are recursively searched."""
|
"""List and dictionaries are recursively searched."""
|
||||||
with self.assertRaises(SynapseError):
|
with self.assertRaises(SynapseError):
|
||||||
event_from_pdu_json(
|
event_from_pdu_json(
|
||||||
|
|
|
@ -13,14 +13,18 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from typing import Any, Dict
|
||||||
from unittest.mock import ANY, Mock, patch
|
from unittest.mock import ANY, Mock, patch
|
||||||
from urllib.parse import parse_qs, urlparse
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.handlers.sso import MappingException
|
from synapse.handlers.sso import MappingException
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import UserID
|
from synapse.types import JsonDict, UserID
|
||||||
|
from synapse.util import Clock
|
||||||
from synapse.util.macaroons import get_value_from_macaroon
|
from synapse.util.macaroons import get_value_from_macaroon
|
||||||
|
|
||||||
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
|
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
|
||||||
|
@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def get_json(url):
|
async def get_json(url: str) -> JsonDict:
|
||||||
# Mock get_json calls to handle jwks & oidc discovery endpoints
|
# Mock get_json calls to handle jwks & oidc discovery endpoints
|
||||||
if url == WELL_KNOWN:
|
if url == WELL_KNOWN:
|
||||||
# Minimal discovery document, as defined in OpenID.Discovery
|
# Minimal discovery document, as defined in OpenID.Discovery
|
||||||
|
@ -116,6 +120,8 @@ async def get_json(url):
|
||||||
elif url == JWKS_URI:
|
elif url == JWKS_URI:
|
||||||
return {"keys": []}
|
return {"keys": []}
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def _key_file_path() -> str:
|
def _key_file_path() -> str:
|
||||||
"""path to a file containing the private half of a test key"""
|
"""path to a file containing the private half of a test key"""
|
||||||
|
@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
if not HAS_OIDC:
|
if not HAS_OIDC:
|
||||||
skip = "requires OIDC"
|
skip = "requires OIDC"
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["public_baseurl"] = BASE_URL
|
config["public_baseurl"] = BASE_URL
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.http_client = Mock(spec=["get_json"])
|
self.http_client = Mock(spec=["get_json"])
|
||||||
self.http_client.get_json.side_effect = get_json
|
self.http_client.get_json.side_effect = get_json
|
||||||
self.http_client.user_agent = b"Synapse Test"
|
self.http_client.user_agent = b"Synapse Test"
|
||||||
|
@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
sso_handler = hs.get_sso_handler()
|
sso_handler = hs.get_sso_handler()
|
||||||
# Mock the render error method.
|
# Mock the render error method.
|
||||||
self.render_error = Mock(return_value=None)
|
self.render_error = Mock(return_value=None)
|
||||||
sso_handler.render_error = self.render_error
|
sso_handler.render_error = self.render_error # type: ignore[assignment]
|
||||||
|
|
||||||
# Reduce the number of attempts when generating MXIDs.
|
# Reduce the number of attempts when generating MXIDs.
|
||||||
sso_handler._MAP_USERNAME_RETRIES = 3
|
sso_handler._MAP_USERNAME_RETRIES = 3
|
||||||
|
@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_config(self):
|
def test_config(self) -> None:
|
||||||
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
||||||
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
|
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
|
||||||
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
|
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
|
||||||
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
|
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
|
||||||
|
|
||||||
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
|
||||||
def test_discovery(self):
|
def test_discovery(self) -> None:
|
||||||
"""The handler should discover the endpoints from OIDC discovery document."""
|
"""The handler should discover the endpoints from OIDC discovery document."""
|
||||||
# This would throw if some metadata were invalid
|
# This would throw if some metadata were invalid
|
||||||
metadata = self.get_success(self.provider.load_metadata())
|
metadata = self.get_success(self.provider.load_metadata())
|
||||||
|
@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
|
|
||||||
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
||||||
def test_no_discovery(self):
|
def test_no_discovery(self) -> None:
|
||||||
"""When discovery is disabled, it should not try to load from discovery document."""
|
"""When discovery is disabled, it should not try to load from discovery document."""
|
||||||
self.get_success(self.provider.load_metadata())
|
self.get_success(self.provider.load_metadata())
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
|
|
||||||
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
||||||
def test_load_jwks(self):
|
def test_load_jwks(self) -> None:
|
||||||
"""JWKS loading is done once (then cached) if used."""
|
"""JWKS loading is done once (then cached) if used."""
|
||||||
jwks = self.get_success(self.provider.load_jwks())
|
jwks = self.get_success(self.provider.load_jwks())
|
||||||
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
||||||
|
@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
|
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_validate_config(self):
|
def test_validate_config(self) -> None:
|
||||||
"""Provider metadatas are extensively validated."""
|
"""Provider metadatas are extensively validated."""
|
||||||
h = self.provider
|
h = self.provider
|
||||||
|
|
||||||
|
@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
force_load_metadata()
|
force_load_metadata()
|
||||||
|
|
||||||
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
|
||||||
def test_skip_verification(self):
|
def test_skip_verification(self) -> None:
|
||||||
"""Provider metadata validation can be disabled by config."""
|
"""Provider metadata validation can be disabled by config."""
|
||||||
with self.metadata_edit({"issuer": "http://insecure"}):
|
with self.metadata_edit({"issuer": "http://insecure"}):
|
||||||
# This should not throw
|
# This should not throw
|
||||||
get_awaitable_result(self.provider.load_metadata())
|
get_awaitable_result(self.provider.load_metadata())
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_redirect_request(self):
|
def test_redirect_request(self) -> None:
|
||||||
"""The redirect request has the right arguments & generates a valid session cookie."""
|
"""The redirect request has the right arguments & generates a valid session cookie."""
|
||||||
req = Mock(spec=["cookies"])
|
req = Mock(spec=["cookies"])
|
||||||
req.cookies = []
|
req.cookies = []
|
||||||
|
@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(redirect, "http://client/redirect")
|
self.assertEqual(redirect, "http://client/redirect")
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_callback_error(self):
|
def test_callback_error(self) -> None:
|
||||||
"""Errors from the provider returned in the callback are displayed."""
|
"""Errors from the provider returned in the callback are displayed."""
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args[b"error"] = [b"invalid_client"]
|
request.args[b"error"] = [b"invalid_client"]
|
||||||
|
@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.assertRenderedError("invalid_client", "some description")
|
self.assertRenderedError("invalid_client", "some description")
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_callback(self):
|
def test_callback(self) -> None:
|
||||||
"""Code callback works and display errors if something went wrong.
|
"""Code callback works and display errors if something went wrong.
|
||||||
|
|
||||||
A lot of scenarios are tested here:
|
A lot of scenarios are tested here:
|
||||||
|
@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"username": username,
|
"username": username,
|
||||||
}
|
}
|
||||||
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
|
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
|
||||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
||||||
self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||||
self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
|
||||||
|
@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.assertRenderedError("mapping_error")
|
self.assertRenderedError("mapping_error")
|
||||||
|
|
||||||
# Handle ID token errors
|
# Handle ID token errors
|
||||||
self.provider._parse_id_token = simple_async_mock(raises=Exception())
|
self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_token")
|
self.assertRenderedError("invalid_token")
|
||||||
|
|
||||||
|
@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"type": "bearer",
|
"type": "bearer",
|
||||||
"access_token": "access_token",
|
"access_token": "access_token",
|
||||||
}
|
}
|
||||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
|
@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
id_token = {
|
id_token = {
|
||||||
"sid": "abcdefgh",
|
"sid": "abcdefgh",
|
||||||
}
|
}
|
||||||
self.provider._parse_id_token = simple_async_mock(return_value=id_token)
|
self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
|
||||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
self.provider._fetch_userinfo.reset_mock()
|
self.provider._fetch_userinfo.reset_mock()
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.render_error.assert_not_called()
|
self.render_error.assert_not_called()
|
||||||
|
|
||||||
# Handle userinfo fetching error
|
# Handle userinfo fetching error
|
||||||
self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
|
self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("fetch_error")
|
self.assertRenderedError("fetch_error")
|
||||||
|
|
||||||
# Handle code exchange failure
|
# Handle code exchange failure
|
||||||
from synapse.handlers.oidc import OidcError
|
from synapse.handlers.oidc import OidcError
|
||||||
|
|
||||||
self.provider._exchange_code = simple_async_mock(
|
self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
|
||||||
raises=OidcError("invalid_request")
|
raises=OidcError("invalid_request")
|
||||||
)
|
)
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_request")
|
self.assertRenderedError("invalid_request")
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_callback_session(self):
|
def test_callback_session(self) -> None:
|
||||||
"""The callback verifies the session presence and validity"""
|
"""The callback verifies the session presence and validity"""
|
||||||
request = Mock(spec=["args", "getCookie", "cookies"])
|
request = Mock(spec=["args", "getCookie", "cookies"])
|
||||||
|
|
||||||
|
@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
@override_config(
|
@override_config(
|
||||||
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
|
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
|
||||||
)
|
)
|
||||||
def test_exchange_code(self):
|
def test_exchange_code(self) -> None:
|
||||||
"""Code exchange behaves correctly and handles various error scenarios."""
|
"""Code exchange behaves correctly and handles various error scenarios."""
|
||||||
token = {"type": "bearer"}
|
token = {"type": "bearer"}
|
||||||
token_json = json.dumps(token).encode("utf-8")
|
token_json = json.dumps(token).encode("utf-8")
|
||||||
|
@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_exchange_code_jwt_key(self):
|
def test_exchange_code_jwt_key(self) -> None:
|
||||||
"""Test that code exchange works with a JWK client secret."""
|
"""Test that code exchange works with a JWK client secret."""
|
||||||
from authlib.jose import jwt
|
from authlib.jose import jwt
|
||||||
|
|
||||||
|
@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_exchange_code_no_auth(self):
|
def test_exchange_code_no_auth(self) -> None:
|
||||||
"""Test that code exchange works with no client secret."""
|
"""Test that code exchange works with no client secret."""
|
||||||
token = {"type": "bearer"}
|
token = {"type": "bearer"}
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = simple_async_mock(
|
||||||
|
@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_extra_attributes(self):
|
def test_extra_attributes(self) -> None:
|
||||||
"""
|
"""
|
||||||
Login while using a mapping provider that implements get_extra_attributes.
|
Login while using a mapping provider that implements get_extra_attributes.
|
||||||
"""
|
"""
|
||||||
|
@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"username": "foo",
|
"username": "foo",
|
||||||
"phone": "1234567",
|
"phone": "1234567",
|
||||||
}
|
}
|
||||||
self.provider._exchange_code = simple_async_mock(return_value=token)
|
self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
||||||
self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
|
||||||
|
@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_map_userinfo_to_user(self):
|
def test_map_userinfo_to_user(self) -> None:
|
||||||
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
|
||||||
userinfo = {
|
userinfo: dict = {
|
||||||
"sub": "test_user",
|
"sub": "test_user",
|
||||||
"username": "test_user",
|
"username": "test_user",
|
||||||
}
|
}
|
||||||
|
@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
|
||||||
def test_map_userinfo_to_existing_user(self):
|
def test_map_userinfo_to_existing_user(self) -> None:
|
||||||
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
|
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
user = UserID.from_string("@test_user:test")
|
user = UserID.from_string("@test_user:test")
|
||||||
|
@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_map_userinfo_to_invalid_localpart(self):
|
def test_map_userinfo_to_invalid_localpart(self) -> None:
|
||||||
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
||||||
self.get_success(
|
self.get_success(
|
||||||
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
|
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
|
||||||
|
@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_map_userinfo_to_user_retries(self):
|
def test_map_userinfo_to_user_retries(self) -> None:
|
||||||
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_empty_localpart(self):
|
def test_empty_localpart(self) -> None:
|
||||||
"""Attempts to map onto an empty localpart should be rejected."""
|
"""Attempts to map onto an empty localpart should be rejected."""
|
||||||
userinfo = {
|
userinfo = {
|
||||||
"sub": "tester",
|
"sub": "tester",
|
||||||
|
@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_null_localpart(self):
|
def test_null_localpart(self) -> None:
|
||||||
"""Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
|
"""Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
|
||||||
userinfo = {
|
userinfo = {
|
||||||
"sub": "tester",
|
"sub": "tester",
|
||||||
|
@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_attribute_requirements(self):
|
def test_attribute_requirements(self) -> None:
|
||||||
"""The required attributes must be met from the OIDC userinfo response."""
|
"""The required attributes must be met from the OIDC userinfo response."""
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_attribute_requirements_contains(self):
|
def test_attribute_requirements_contains(self) -> None:
|
||||||
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
|
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
|
@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_attribute_requirements_mismatch(self):
|
def test_attribute_requirements_mismatch(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test that auth fails if attributes exist but don't match,
|
Test that auth fails if attributes exist but don't match,
|
||||||
or are non-string values.
|
or are non-string values.
|
||||||
|
@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock()
|
||||||
# userinfo with "test": "not_foobar" attribute should fail
|
# userinfo with "test": "not_foobar" attribute should fail
|
||||||
userinfo = {
|
userinfo: dict = {
|
||||||
"sub": "tester",
|
"sub": "tester",
|
||||||
"username": "tester",
|
"username": "tester",
|
||||||
"test": "not_foobar",
|
"test": "not_foobar",
|
||||||
|
@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo(
|
||||||
|
|
||||||
handler = hs.get_oidc_handler()
|
handler = hs.get_oidc_handler()
|
||||||
provider = handler._providers["oidc"]
|
provider = handler._providers["oidc"]
|
||||||
provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
|
provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
|
||||||
provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||||
provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
||||||
|
|
||||||
state = "state"
|
state = "state"
|
||||||
session = handler._token_generator.generate_oidc_session_token(
|
session = handler._token_generator.generate_oidc_session_token(
|
||||||
|
|
|
@ -331,11 +331,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Extract presence update user ID and state information into lists of tuples
|
# Extract presence update user ID and state information into lists of tuples
|
||||||
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
|
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
|
||||||
presence_states = [(ps.user_id, ps.state) for ps in presence_states]
|
presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
|
||||||
|
|
||||||
# Compare what we put into the storage with what we got out.
|
# Compare what we put into the storage with what we got out.
|
||||||
# They should be identical.
|
# They should be identical.
|
||||||
self.assertEqual(presence_states, db_presence_states)
|
self.assertEqual(presence_states_compare, db_presence_states)
|
||||||
|
|
||||||
|
|
||||||
class PresenceTimeoutTestCase(unittest.TestCase):
|
class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
|
@ -357,6 +357,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
|
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
|
||||||
|
|
||||||
self.assertIsNotNone(new_state)
|
self.assertIsNotNone(new_state)
|
||||||
|
assert new_state is not None
|
||||||
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
|
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
|
||||||
self.assertEqual(new_state.status_msg, status_msg)
|
self.assertEqual(new_state.status_msg, status_msg)
|
||||||
|
|
||||||
|
@ -380,6 +381,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
|
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
|
||||||
|
|
||||||
self.assertIsNotNone(new_state)
|
self.assertIsNotNone(new_state)
|
||||||
|
assert new_state is not None
|
||||||
self.assertEqual(new_state.state, PresenceState.BUSY)
|
self.assertEqual(new_state.state, PresenceState.BUSY)
|
||||||
self.assertEqual(new_state.status_msg, status_msg)
|
self.assertEqual(new_state.status_msg, status_msg)
|
||||||
|
|
||||||
|
@ -399,6 +401,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
|
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
|
||||||
|
|
||||||
self.assertIsNotNone(new_state)
|
self.assertIsNotNone(new_state)
|
||||||
|
assert new_state is not None
|
||||||
self.assertEqual(new_state.state, PresenceState.OFFLINE)
|
self.assertEqual(new_state.state, PresenceState.OFFLINE)
|
||||||
self.assertEqual(new_state.status_msg, status_msg)
|
self.assertEqual(new_state.status_msg, status_msg)
|
||||||
|
|
||||||
|
@ -420,6 +423,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNotNone(new_state)
|
self.assertIsNotNone(new_state)
|
||||||
|
assert new_state is not None
|
||||||
self.assertEqual(new_state.state, PresenceState.ONLINE)
|
self.assertEqual(new_state.state, PresenceState.ONLINE)
|
||||||
self.assertEqual(new_state.status_msg, status_msg)
|
self.assertEqual(new_state.status_msg, status_msg)
|
||||||
|
|
||||||
|
@ -477,6 +481,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNotNone(new_state)
|
self.assertIsNotNone(new_state)
|
||||||
|
assert new_state is not None
|
||||||
self.assertEqual(new_state.state, PresenceState.OFFLINE)
|
self.assertEqual(new_state.state, PresenceState.OFFLINE)
|
||||||
self.assertEqual(new_state.status_msg, status_msg)
|
self.assertEqual(new_state.status_msg, status_msg)
|
||||||
|
|
||||||
|
@ -653,13 +658,13 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
|
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
|
||||||
|
|
||||||
def _set_presencestate_with_status_msg(
|
def _set_presencestate_with_status_msg(
|
||||||
self, user_id: str, state: PresenceState, status_msg: Optional[str]
|
self, user_id: str, state: str, status_msg: Optional[str]
|
||||||
):
|
):
|
||||||
"""Set a PresenceState and status_msg and check the result.
|
"""Set a PresenceState and status_msg and check the result.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: User for that the status is to be set.
|
user_id: User for that the status is to be set.
|
||||||
PresenceState: The new PresenceState.
|
state: The new PresenceState.
|
||||||
status_msg: Status message that is to be set.
|
status_msg: Status message that is to be set.
|
||||||
"""
|
"""
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
|
|
@ -11,14 +11,17 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Dict
|
from typing import Any, Awaitable, Callable, Dict
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.types
|
import synapse.types
|
||||||
from synapse.api.errors import AuthError, SynapseError
|
from synapse.api.errors import AuthError, SynapseError
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import UserID
|
from synapse.types import JsonDict, UserID
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import make_awaitable
|
from tests.test_utils import make_awaitable
|
||||||
|
@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
servlets = [admin.register_servlets]
|
servlets = [admin.register_servlets]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.mock_federation = Mock()
|
self.mock_federation = Mock()
|
||||||
self.mock_registry = Mock()
|
self.mock_registry = Mock()
|
||||||
|
|
||||||
self.query_handlers = {}
|
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
|
||||||
|
|
||||||
def register_query_handler(query_type, handler):
|
def register_query_handler(
|
||||||
|
query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
|
||||||
|
) -> None:
|
||||||
self.query_handlers[query_type] = handler
|
self.query_handlers[query_type] = handler
|
||||||
|
|
||||||
self.mock_registry.register_query_handler = register_query_handler
|
self.mock_registry.register_query_handler = register_query_handler
|
||||||
|
@ -47,7 +52,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs: HomeServer):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
|
||||||
self.frank = UserID.from_string("@1234abcd:test")
|
self.frank = UserID.from_string("@1234abcd:test")
|
||||||
|
@ -58,7 +63,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.handler = hs.get_profile_handler()
|
self.handler = hs.get_profile_handler()
|
||||||
|
|
||||||
def test_get_my_name(self):
|
def test_get_my_name(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
||||||
)
|
)
|
||||||
|
@ -67,7 +72,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual("Frank", displayname)
|
self.assertEqual("Frank", displayname)
|
||||||
|
|
||||||
def test_set_my_name(self):
|
def test_set_my_name(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.handler.set_displayname(
|
self.handler.set_displayname(
|
||||||
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
|
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
|
||||||
|
@ -110,7 +115,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
|
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_set_my_name_if_disabled(self):
|
def test_set_my_name_if_disabled(self) -> None:
|
||||||
self.hs.config.registration.enable_set_displayname = False
|
self.hs.config.registration.enable_set_displayname = False
|
||||||
|
|
||||||
# Setting displayname for the first time is allowed
|
# Setting displayname for the first time is allowed
|
||||||
|
@ -135,7 +140,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_set_my_name_noauth(self):
|
def test_set_my_name_noauth(self) -> None:
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.handler.set_displayname(
|
self.handler.set_displayname(
|
||||||
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
|
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
|
||||||
|
@ -143,7 +148,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
AuthError,
|
AuthError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_other_name(self):
|
def test_get_other_name(self) -> None:
|
||||||
self.mock_federation.make_query.return_value = make_awaitable(
|
self.mock_federation.make_query.return_value = make_awaitable(
|
||||||
{"displayname": "Alice"}
|
{"displayname": "Alice"}
|
||||||
)
|
)
|
||||||
|
@ -158,7 +163,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
ignore_backoff=True,
|
ignore_backoff=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_incoming_fed_query(self):
|
def test_incoming_fed_query(self) -> None:
|
||||||
self.get_success(self.store.create_profile("caroline"))
|
self.get_success(self.store.create_profile("caroline"))
|
||||||
self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
|
self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
|
||||||
|
|
||||||
|
@ -174,7 +179,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual({"displayname": "Caroline"}, response)
|
self.assertEqual({"displayname": "Caroline"}, response)
|
||||||
|
|
||||||
def test_get_my_avatar(self):
|
def test_get_my_avatar(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.set_profile_avatar_url(
|
self.store.set_profile_avatar_url(
|
||||||
self.frank.localpart, "http://my.server/me.png"
|
self.frank.localpart, "http://my.server/me.png"
|
||||||
|
@ -184,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual("http://my.server/me.png", avatar_url)
|
self.assertEqual("http://my.server/me.png", avatar_url)
|
||||||
|
|
||||||
def test_set_my_avatar(self):
|
def test_set_my_avatar(self) -> None:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.handler.set_avatar_url(
|
self.handler.set_avatar_url(
|
||||||
self.frank,
|
self.frank,
|
||||||
|
@ -225,7 +230,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_set_my_avatar_if_disabled(self):
|
def test_set_my_avatar_if_disabled(self) -> None:
|
||||||
self.hs.config.registration.enable_set_avatar_url = False
|
self.hs.config.registration.enable_set_avatar_url = False
|
||||||
|
|
||||||
# Setting displayname for the first time is allowed
|
# Setting displayname for the first time is allowed
|
||||||
|
@ -250,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_avatar_constraints_no_config(self):
|
def test_avatar_constraints_no_config(self) -> None:
|
||||||
"""Tests that the method to check an avatar against configured constraints skips
|
"""Tests that the method to check an avatar against configured constraints skips
|
||||||
all of its check if no constraint is configured.
|
all of its check if no constraint is configured.
|
||||||
"""
|
"""
|
||||||
|
@ -263,7 +268,13 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertTrue(res)
|
self.assertTrue(res)
|
||||||
|
|
||||||
@unittest.override_config({"max_avatar_size": 50})
|
@unittest.override_config({"max_avatar_size": 50})
|
||||||
def test_avatar_constraints_missing(self):
|
def test_avatar_constraints_allow_empty_avatar_url(self) -> None:
|
||||||
|
"""An empty avatar is always permitted."""
|
||||||
|
res = self.get_success(self.handler.check_avatar_size_and_mime_type(""))
|
||||||
|
self.assertTrue(res)
|
||||||
|
|
||||||
|
@unittest.override_config({"max_avatar_size": 50})
|
||||||
|
def test_avatar_constraints_missing(self) -> None:
|
||||||
"""Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
|
"""Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
|
||||||
be found.
|
be found.
|
||||||
"""
|
"""
|
||||||
|
@ -273,7 +284,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertFalse(res)
|
self.assertFalse(res)
|
||||||
|
|
||||||
@unittest.override_config({"max_avatar_size": 50})
|
@unittest.override_config({"max_avatar_size": 50})
|
||||||
def test_avatar_constraints_file_size(self):
|
def test_avatar_constraints_file_size(self) -> None:
|
||||||
"""Tests that a file that's above the allowed file size is forbidden but one
|
"""Tests that a file that's above the allowed file size is forbidden but one
|
||||||
that's below it is allowed.
|
that's below it is allowed.
|
||||||
"""
|
"""
|
||||||
|
@ -295,7 +306,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertFalse(res)
|
self.assertFalse(res)
|
||||||
|
|
||||||
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
|
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
|
||||||
def test_avatar_constraint_mime_type(self):
|
def test_avatar_constraint_mime_type(self) -> None:
|
||||||
"""Tests that a file with an unauthorised MIME type is forbidden but one with
|
"""Tests that a file with an unauthorised MIME type is forbidden but one with
|
||||||
an authorised content type is allowed.
|
an authorised content type is allowed.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -12,12 +12,16 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Any, Dict, Optional
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.errors import RedirectException
|
from synapse.api.errors import RedirectException
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.test_utils import simple_async_mock
|
from tests.test_utils import simple_async_mock
|
||||||
from tests.unittest import HomeserverTestCase, override_config
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
@ -81,10 +85,10 @@ class TestRedirectMappingProvider(TestMappingProvider):
|
||||||
|
|
||||||
|
|
||||||
class SamlHandlerTestCase(HomeserverTestCase):
|
class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["public_baseurl"] = BASE_URL
|
config["public_baseurl"] = BASE_URL
|
||||||
saml_config = {
|
saml_config: Dict[str, Any] = {
|
||||||
"sp_config": {"metadata": {}},
|
"sp_config": {"metadata": {}},
|
||||||
# Disable grandfathering.
|
# Disable grandfathering.
|
||||||
"grandfathered_mxid_source_attribute": None,
|
"grandfathered_mxid_source_attribute": None,
|
||||||
|
@ -98,7 +102,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
hs = self.setup_test_homeserver()
|
hs = self.setup_test_homeserver()
|
||||||
|
|
||||||
self.handler = hs.get_saml_handler()
|
self.handler = hs.get_saml_handler()
|
||||||
|
@ -114,7 +118,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
elif not has_xmlsec1:
|
elif not has_xmlsec1:
|
||||||
skip = "Requires xmlsec1"
|
skip = "Requires xmlsec1"
|
||||||
|
|
||||||
def test_map_saml_response_to_user(self):
|
def test_map_saml_response_to_user(self) -> None:
|
||||||
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
|
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
|
@ -140,7 +144,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
||||||
def test_map_saml_response_to_existing_user(self):
|
def test_map_saml_response_to_existing_user(self) -> None:
|
||||||
"""Existing users can log in with SAML account."""
|
"""Existing users can log in with SAML account."""
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -186,7 +190,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
auth_provider_session_id=None,
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_map_saml_response_to_invalid_localpart(self):
|
def test_map_saml_response_to_invalid_localpart(self) -> None:
|
||||||
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
|
@ -207,7 +211,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.assert_not_called()
|
auth_handler.complete_sso_login.assert_not_called()
|
||||||
|
|
||||||
def test_map_saml_response_to_user_retries(self):
|
def test_map_saml_response_to_user_retries(self) -> None:
|
||||||
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
||||||
|
|
||||||
# stub out the auth handler and error renderer
|
# stub out the auth handler and error renderer
|
||||||
|
@ -271,7 +275,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_map_saml_response_redirect(self):
|
def test_map_saml_response_redirect(self) -> None:
|
||||||
"""Test a mapping provider that raises a RedirectException"""
|
"""Test a mapping provider that raises a RedirectException"""
|
||||||
|
|
||||||
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
|
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
|
||||||
|
@ -292,7 +296,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_attribute_requirements(self):
|
def test_attribute_requirements(self) -> None:
|
||||||
"""The required attributes must be met from the SAML response."""
|
"""The required attributes must be met from the SAML response."""
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
|
|
|
@ -18,11 +18,14 @@ from typing import Dict
|
||||||
from unittest.mock import ANY, Mock, call
|
from unittest.mock import ANY, Mock, call
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.federation.transport.server import TransportLayerServer
|
from synapse.federation.transport.server import TransportLayerServer
|
||||||
from synapse.types import UserID, create_requester
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict, UserID, create_requester
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import make_awaitable
|
from tests.test_utils import make_awaitable
|
||||||
|
@ -42,7 +45,9 @@ ROOM_ID = "a-room"
|
||||||
OTHER_ROOM_ID = "another-room"
|
OTHER_ROOM_ID = "another-room"
|
||||||
|
|
||||||
|
|
||||||
def _expect_edu_transaction(edu_type, content, origin="test"):
|
def _expect_edu_transaction(
|
||||||
|
edu_type: str, content: JsonDict, origin: str = "test"
|
||||||
|
) -> JsonDict:
|
||||||
return {
|
return {
|
||||||
"origin": origin,
|
"origin": origin,
|
||||||
"origin_server_ts": 1000000,
|
"origin_server_ts": 1000000,
|
||||||
|
@ -51,12 +56,12 @@ def _expect_edu_transaction(edu_type, content, origin="test"):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _make_edu_transaction_json(edu_type, content):
|
def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes:
|
||||||
return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")
|
return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")
|
||||||
|
|
||||||
|
|
||||||
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
# we mock out the keyring so as to skip the authentication check on the
|
# we mock out the keyring so as to skip the authentication check on the
|
||||||
# federation API call.
|
# federation API call.
|
||||||
mock_keyring = Mock(spec=["verify_json_for_server"])
|
mock_keyring = Mock(spec=["verify_json_for_server"])
|
||||||
|
@ -83,7 +88,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
d["/_matrix/federation"] = TransportLayerServer(self.hs)
|
d["/_matrix/federation"] = TransportLayerServer(self.hs)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
mock_notifier = hs.get_notifier()
|
mock_notifier = hs.get_notifier()
|
||||||
self.on_new_event = mock_notifier.on_new_event
|
self.on_new_event = mock_notifier.on_new_event
|
||||||
|
|
||||||
|
@ -111,24 +116,24 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.room_members = []
|
self.room_members = []
|
||||||
|
|
||||||
async def check_user_in_room(room_id, user_id):
|
async def check_user_in_room(room_id: str, user_id: str) -> None:
|
||||||
if user_id not in [u.to_string() for u in self.room_members]:
|
if user_id not in [u.to_string() for u in self.room_members]:
|
||||||
raise AuthError(401, "User is not in the room")
|
raise AuthError(401, "User is not in the room")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
hs.get_auth().check_user_in_room = check_user_in_room
|
hs.get_auth().check_user_in_room = check_user_in_room
|
||||||
|
|
||||||
async def check_host_in_room(room_id, server_name):
|
async def check_host_in_room(room_id: str, server_name: str) -> bool:
|
||||||
return room_id == ROOM_ID
|
return room_id == ROOM_ID
|
||||||
|
|
||||||
hs.get_event_auth_handler().check_host_in_room = check_host_in_room
|
hs.get_event_auth_handler().check_host_in_room = check_host_in_room
|
||||||
|
|
||||||
def get_joined_hosts_for_room(room_id):
|
def get_joined_hosts_for_room(room_id: str):
|
||||||
return {member.domain for member in self.room_members}
|
return {member.domain for member in self.room_members}
|
||||||
|
|
||||||
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
|
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
|
||||||
|
|
||||||
async def get_users_in_room(room_id):
|
async def get_users_in_room(room_id: str):
|
||||||
return {str(u) for u in self.room_members}
|
return {str(u) for u in self.room_members}
|
||||||
|
|
||||||
self.datastore.get_users_in_room = get_users_in_room
|
self.datastore.get_users_in_room = get_users_in_room
|
||||||
|
@ -153,7 +158,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
lambda *args, **kwargs: make_awaitable(None)
|
lambda *args, **kwargs: make_awaitable(None)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_started_typing_local(self):
|
def test_started_typing_local(self) -> None:
|
||||||
self.room_members = [U_APPLE, U_BANANA]
|
self.room_members = [U_APPLE, U_BANANA]
|
||||||
|
|
||||||
self.assertEqual(self.event_source.get_current_key(), 0)
|
self.assertEqual(self.event_source.get_current_key(), 0)
|
||||||
|
@ -187,7 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"send_federation": True})
|
@override_config({"send_federation": True})
|
||||||
def test_started_typing_remote_send(self):
|
def test_started_typing_remote_send(self) -> None:
|
||||||
self.room_members = [U_APPLE, U_ONION]
|
self.room_members = [U_APPLE, U_ONION]
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -217,7 +222,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
try_trailing_slash_on_400=True,
|
try_trailing_slash_on_400=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_started_typing_remote_recv(self):
|
def test_started_typing_remote_recv(self) -> None:
|
||||||
self.room_members = [U_APPLE, U_ONION]
|
self.room_members = [U_APPLE, U_ONION]
|
||||||
|
|
||||||
self.assertEqual(self.event_source.get_current_key(), 0)
|
self.assertEqual(self.event_source.get_current_key(), 0)
|
||||||
|
@ -256,7 +261,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_started_typing_remote_recv_not_in_room(self):
|
def test_started_typing_remote_recv_not_in_room(self) -> None:
|
||||||
self.room_members = [U_APPLE, U_ONION]
|
self.room_members = [U_APPLE, U_ONION]
|
||||||
|
|
||||||
self.assertEqual(self.event_source.get_current_key(), 0)
|
self.assertEqual(self.event_source.get_current_key(), 0)
|
||||||
|
@ -292,7 +297,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(events[1], 0)
|
self.assertEqual(events[1], 0)
|
||||||
|
|
||||||
@override_config({"send_federation": True})
|
@override_config({"send_federation": True})
|
||||||
def test_stopped_typing(self):
|
def test_stopped_typing(self) -> None:
|
||||||
self.room_members = [U_APPLE, U_BANANA, U_ONION]
|
self.room_members = [U_APPLE, U_BANANA, U_ONION]
|
||||||
|
|
||||||
# Gut-wrenching
|
# Gut-wrenching
|
||||||
|
@ -343,7 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
|
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_typing_timeout(self):
|
def test_typing_timeout(self) -> None:
|
||||||
self.room_members = [U_APPLE, U_BANANA]
|
self.room_members = [U_APPLE, U_BANANA]
|
||||||
|
|
||||||
self.assertEqual(self.event_source.get_current_key(), 0)
|
self.assertEqual(self.event_source.get_current_key(), 0)
|
||||||
|
|
|
@ -86,6 +86,16 @@ class ModuleApiTestCase(HomeserverTestCase):
|
||||||
displayname = self.get_success(self.store.get_profile_displayname("bob"))
|
displayname = self.get_success(self.store.get_profile_displayname("bob"))
|
||||||
self.assertEqual(displayname, "Bobberino")
|
self.assertEqual(displayname, "Bobberino")
|
||||||
|
|
||||||
|
def test_can_register_admin_user(self):
|
||||||
|
user_id = self.get_success(
|
||||||
|
self.register_user(
|
||||||
|
"bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
|
||||||
|
self.assertEqual(found_user.user_id.to_string(), user_id)
|
||||||
|
self.assertIdentical(found_user.is_admin, True)
|
||||||
|
|
||||||
def test_get_userinfo_by_id(self):
|
def test_get_userinfo_by_id(self):
|
||||||
user_id = self.register_user("alice", "1234")
|
user_id = self.register_user("alice", "1234")
|
||||||
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
|
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
|
||||||
|
|
|
@ -11,15 +11,19 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import List, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.push import PusherConfigException
|
from synapse.push import PusherConfigException
|
||||||
from synapse.rest.client import login, push_rule, receipts, room
|
from synapse.rest.client import login, push_rule, receipts, room
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase, override_config
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
|
||||||
|
@ -35,13 +39,13 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
user_id = True
|
user_id = True
|
||||||
hijack_auth = False
|
hijack_auth = False
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["start_pushers"] = True
|
config["start_pushers"] = True
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.push_attempts: List[tuple[Deferred, str, dict]] = []
|
self.push_attempts: List[Tuple[Deferred, str, dict]] = []
|
||||||
|
|
||||||
m = Mock()
|
m = Mock()
|
||||||
|
|
||||||
|
@ -56,7 +60,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def test_invalid_configuration(self):
|
def test_invalid_configuration(self) -> None:
|
||||||
"""Invalid push configurations should be rejected."""
|
"""Invalid push configurations should be rejected."""
|
||||||
# Register the user who gets notified
|
# Register the user who gets notified
|
||||||
user_id = self.register_user("user", "pass")
|
user_id = self.register_user("user", "pass")
|
||||||
|
@ -68,7 +72,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
def test_data(data):
|
def test_data(data: Optional[JsonDict]) -> None:
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
@ -95,7 +99,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
# A url with an incorrect path isn't accepted.
|
# A url with an incorrect path isn't accepted.
|
||||||
test_data({"url": "http://example.com/foo"})
|
test_data({"url": "http://example.com/foo"})
|
||||||
|
|
||||||
def test_sends_http(self):
|
def test_sends_http(self) -> None:
|
||||||
"""
|
"""
|
||||||
The HTTP pusher will send pushes for each message to a HTTP endpoint
|
The HTTP pusher will send pushes for each message to a HTTP endpoint
|
||||||
when configured to do so.
|
when configured to do so.
|
||||||
|
@ -200,7 +204,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
|
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
|
||||||
|
|
||||||
def test_sends_high_priority_for_encrypted(self):
|
def test_sends_high_priority_for_encrypted(self) -> None:
|
||||||
"""
|
"""
|
||||||
The HTTP pusher will send pushes at high priority if they correspond
|
The HTTP pusher will send pushes at high priority if they correspond
|
||||||
to an encrypted message.
|
to an encrypted message.
|
||||||
|
@ -321,7 +325,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
|
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
|
||||||
|
|
||||||
def test_sends_high_priority_for_one_to_one_only(self):
|
def test_sends_high_priority_for_one_to_one_only(self) -> None:
|
||||||
"""
|
"""
|
||||||
The HTTP pusher will send pushes at high priority if they correspond
|
The HTTP pusher will send pushes at high priority if they correspond
|
||||||
to a message in a one-to-one room.
|
to a message in a one-to-one room.
|
||||||
|
@ -404,7 +408,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
# check that this is low-priority
|
# check that this is low-priority
|
||||||
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
|
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
|
||||||
|
|
||||||
def test_sends_high_priority_for_mention(self):
|
def test_sends_high_priority_for_mention(self) -> None:
|
||||||
"""
|
"""
|
||||||
The HTTP pusher will send pushes at high priority if they correspond
|
The HTTP pusher will send pushes at high priority if they correspond
|
||||||
to a message containing the user's display name.
|
to a message containing the user's display name.
|
||||||
|
@ -480,7 +484,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
# check that this is low-priority
|
# check that this is low-priority
|
||||||
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
|
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
|
||||||
|
|
||||||
def test_sends_high_priority_for_atroom(self):
|
def test_sends_high_priority_for_atroom(self) -> None:
|
||||||
"""
|
"""
|
||||||
The HTTP pusher will send pushes at high priority if they correspond
|
The HTTP pusher will send pushes at high priority if they correspond
|
||||||
to a message that contains @room.
|
to a message that contains @room.
|
||||||
|
@ -563,7 +567,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
# check that this is low-priority
|
# check that this is low-priority
|
||||||
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
|
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
|
||||||
|
|
||||||
def test_push_unread_count_group_by_room(self):
|
def test_push_unread_count_group_by_room(self) -> None:
|
||||||
"""
|
"""
|
||||||
The HTTP pusher will group unread count by number of unread rooms.
|
The HTTP pusher will group unread count by number of unread rooms.
|
||||||
"""
|
"""
|
||||||
|
@ -576,7 +580,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
self._check_push_attempt(6, 1)
|
self._check_push_attempt(6, 1)
|
||||||
|
|
||||||
@override_config({"push": {"group_unread_count_by_room": False}})
|
@override_config({"push": {"group_unread_count_by_room": False}})
|
||||||
def test_push_unread_count_message_count(self):
|
def test_push_unread_count_message_count(self) -> None:
|
||||||
"""
|
"""
|
||||||
The HTTP pusher will send the total unread message count.
|
The HTTP pusher will send the total unread message count.
|
||||||
"""
|
"""
|
||||||
|
@ -589,7 +593,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
# last read receipt
|
# last read receipt
|
||||||
self._check_push_attempt(6, 3)
|
self._check_push_attempt(6, 3)
|
||||||
|
|
||||||
def _test_push_unread_count(self):
|
def _test_push_unread_count(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests that the correct unread count appears in sent push notifications
|
Tests that the correct unread count appears in sent push notifications
|
||||||
|
|
||||||
|
@ -681,7 +685,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
|
|
||||||
self.helper.send(room_id, body="HELLO???", tok=other_access_token)
|
self.helper.send(room_id, body="HELLO???", tok=other_access_token)
|
||||||
|
|
||||||
def _advance_time_and_make_push_succeed(self, expected_push_attempts):
|
def _advance_time_and_make_push_succeed(self, expected_push_attempts: int) -> None:
|
||||||
self.pump()
|
self.pump()
|
||||||
self.push_attempts[expected_push_attempts - 1][0].callback({})
|
self.push_attempts[expected_push_attempts - 1][0].callback({})
|
||||||
|
|
||||||
|
@ -708,7 +712,9 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
expected_unread_count_last_push,
|
expected_unread_count_last_push,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _send_read_request(self, access_token, message_event_id, room_id):
|
def _send_read_request(
|
||||||
|
self, access_token: str, message_event_id: str, room_id: str
|
||||||
|
) -> None:
|
||||||
# Now set the user's read receipt position to the first event
|
# Now set the user's read receipt position to the first event
|
||||||
#
|
#
|
||||||
# This will actually trigger a new notification to be sent out so that
|
# This will actually trigger a new notification to be sent out so that
|
||||||
|
@ -748,7 +754,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
|
|
||||||
return user_id, access_token
|
return user_id, access_token
|
||||||
|
|
||||||
def test_dont_notify_rule_overrides_message(self):
|
def test_dont_notify_rule_overrides_message(self) -> None:
|
||||||
"""
|
"""
|
||||||
The override push rule will suppress notification
|
The override push rule will suppress notification
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
import frozendict
|
import frozendict
|
||||||
|
|
||||||
|
@ -20,12 +20,13 @@ from synapse.api.room_versions import RoomVersions
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
from synapse.push import push_rule_evaluator
|
from synapse.push import push_rule_evaluator
|
||||||
from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
|
from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class PushRuleEvaluatorTestCase(unittest.TestCase):
|
class PushRuleEvaluatorTestCase(unittest.TestCase):
|
||||||
def _get_evaluator(self, content):
|
def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent:
|
||||||
event = FrozenEvent(
|
event = FrozenEvent(
|
||||||
{
|
{
|
||||||
"event_id": "$event_id",
|
"event_id": "$event_id",
|
||||||
|
@ -39,12 +40,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
room_member_count = 0
|
room_member_count = 0
|
||||||
sender_power_level = 0
|
sender_power_level = 0
|
||||||
power_levels = {}
|
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
|
||||||
return PushRuleEvaluatorForEvent(
|
return PushRuleEvaluatorForEvent(
|
||||||
event, room_member_count, sender_power_level, power_levels
|
event, room_member_count, sender_power_level, power_levels
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_display_name(self):
|
def test_display_name(self) -> None:
|
||||||
"""Check for a matching display name in the body of the event."""
|
"""Check for a matching display name in the body of the event."""
|
||||||
evaluator = self._get_evaluator({"body": "foo bar baz"})
|
evaluator = self._get_evaluator({"body": "foo bar baz"})
|
||||||
|
|
||||||
|
@ -71,20 +72,20 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
|
||||||
self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
|
self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
|
||||||
|
|
||||||
def _assert_matches(
|
def _assert_matches(
|
||||||
self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
|
self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
evaluator = self._get_evaluator(content)
|
evaluator = self._get_evaluator(content)
|
||||||
self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
|
self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
|
||||||
|
|
||||||
def _assert_not_matches(
|
def _assert_not_matches(
|
||||||
self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
|
self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
evaluator = self._get_evaluator(content)
|
evaluator = self._get_evaluator(content)
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
evaluator.matches(condition, "@user:test", "display_name"), msg
|
evaluator.matches(condition, "@user:test", "display_name"), msg
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_event_match_body(self):
|
def test_event_match_body(self) -> None:
|
||||||
"""Check that event_match conditions on content.body work as expected"""
|
"""Check that event_match conditions on content.body work as expected"""
|
||||||
|
|
||||||
# if the key is `content.body`, the pattern matches substrings.
|
# if the key is `content.body`, the pattern matches substrings.
|
||||||
|
@ -165,7 +166,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
|
||||||
r"? after \ should match any character",
|
r"? after \ should match any character",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_event_match_non_body(self):
|
def test_event_match_non_body(self) -> None:
|
||||||
"""Check that event_match conditions on other keys work as expected"""
|
"""Check that event_match conditions on other keys work as expected"""
|
||||||
|
|
||||||
# if the key is anything other than 'content.body', the pattern must match the
|
# if the key is anything other than 'content.body', the pattern must match the
|
||||||
|
@ -241,7 +242,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
|
||||||
"pattern should not match before a newline",
|
"pattern should not match before a newline",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_no_body(self):
|
def test_no_body(self) -> None:
|
||||||
"""Not having a body shouldn't break the evaluator."""
|
"""Not having a body shouldn't break the evaluator."""
|
||||||
evaluator = self._get_evaluator({})
|
evaluator = self._get_evaluator({})
|
||||||
|
|
||||||
|
@ -250,7 +251,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
|
||||||
}
|
}
|
||||||
self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
|
self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
|
||||||
|
|
||||||
def test_invalid_body(self):
|
def test_invalid_body(self) -> None:
|
||||||
"""A non-string body should not break the evaluator."""
|
"""A non-string body should not break the evaluator."""
|
||||||
condition = {
|
condition = {
|
||||||
"kind": "contains_display_name",
|
"kind": "contains_display_name",
|
||||||
|
@ -260,7 +261,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
|
||||||
evaluator = self._get_evaluator({"body": body})
|
evaluator = self._get_evaluator({"body": body})
|
||||||
self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
|
self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
|
||||||
|
|
||||||
def test_tweaks_for_actions(self):
|
def test_tweaks_for_actions(self) -> None:
|
||||||
"""
|
"""
|
||||||
This tests the behaviour of tweaks_for_actions.
|
This tests the behaviour of tweaks_for_actions.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1050,6 +1050,25 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self._is_erased("@user:test", True)
|
self._is_erased("@user:test", True)
|
||||||
|
|
||||||
|
@override_config({"max_avatar_size": 1234})
|
||||||
|
def test_deactivate_user_erase_true_avatar_nonnull_but_empty(self) -> None:
|
||||||
|
"""Check we can erase a user whose avatar is the empty string.
|
||||||
|
|
||||||
|
Reproduces #12257.
|
||||||
|
"""
|
||||||
|
# Patch `self.other_user` to have an empty string as their avatar.
|
||||||
|
self.get_success(self.store.set_profile_avatar_url("user", ""))
|
||||||
|
|
||||||
|
# Check we can still erase them.
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.url,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
content={"erase": True},
|
||||||
|
)
|
||||||
|
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
|
||||||
|
self._is_erased("@user:test", True)
|
||||||
|
|
||||||
def test_deactivate_user_erase_false(self) -> None:
|
def test_deactivate_user_erase_false(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test deactivating a user and set `erase` to `false`
|
Test deactivating a user and set `erase` to `false`
|
||||||
|
|
|
@ -31,7 +31,7 @@ from synapse.rest import admin
|
||||||
from synapse.rest.client import account, login, register, room
|
from synapse.rest.client import account, login, register, room
|
||||||
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
|
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict, UserID
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -1222,6 +1222,62 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||||
expected_failures=[users[2]],
|
expected_failures=[users[2]],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@unittest.override_config(
|
||||||
|
{
|
||||||
|
"use_account_validity_in_account_status": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_no_account_validity(self) -> None:
|
||||||
|
"""Tests that if we decide to include account validity in the response but no
|
||||||
|
account validity 'is_user_expired' callback is provided, we default to marking all
|
||||||
|
users as not expired.
|
||||||
|
"""
|
||||||
|
user = self.register_user("someuser", "password")
|
||||||
|
|
||||||
|
self._test_status(
|
||||||
|
users=[user],
|
||||||
|
expected_statuses={
|
||||||
|
user: {
|
||||||
|
"exists": True,
|
||||||
|
"deactivated": False,
|
||||||
|
"org.matrix.expired": False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected_failures=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
@unittest.override_config(
|
||||||
|
{
|
||||||
|
"use_account_validity_in_account_status": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_account_validity_expired(self) -> None:
|
||||||
|
"""Test that if we decide to include account validity in the response and the user
|
||||||
|
is expired, we return the correct info.
|
||||||
|
"""
|
||||||
|
user = self.register_user("someuser", "password")
|
||||||
|
|
||||||
|
async def is_expired(user_id: str) -> bool:
|
||||||
|
# We can't blindly say everyone is expired, otherwise the request to get the
|
||||||
|
# account status will fail.
|
||||||
|
return UserID.from_string(user_id).localpart == "someuser"
|
||||||
|
|
||||||
|
self.hs.get_account_validity_handler()._is_user_expired_callbacks.append(
|
||||||
|
is_expired
|
||||||
|
)
|
||||||
|
|
||||||
|
self._test_status(
|
||||||
|
users=[user],
|
||||||
|
expected_statuses={
|
||||||
|
user: {
|
||||||
|
"exists": True,
|
||||||
|
"deactivated": False,
|
||||||
|
"org.matrix.expired": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected_failures=[],
|
||||||
|
)
|
||||||
|
|
||||||
def _test_status(
|
def _test_status(
|
||||||
self,
|
self,
|
||||||
users: Optional[List[str]],
|
users: Optional[List[str]],
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.rest.client import login, room, shared_rooms
|
from synapse.rest.client import login, mutual_rooms, room
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
@ -22,16 +22,16 @@ from tests import unittest
|
||||||
from tests.server import FakeChannel
|
from tests.server import FakeChannel
|
||||||
|
|
||||||
|
|
||||||
class UserSharedRoomsTest(unittest.HomeserverTestCase):
|
class UserMutualRoomsTest(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
Tests the UserSharedRoomsServlet.
|
Tests the UserMutualRoomsServlet.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
servlets = [
|
servlets = [
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||||
room.register_servlets,
|
room.register_servlets,
|
||||||
shared_rooms.register_servlets,
|
mutual_rooms.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
@ -43,10 +43,10 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.handler = hs.get_user_directory_handler()
|
self.handler = hs.get_user_directory_handler()
|
||||||
|
|
||||||
def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel:
|
def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel:
|
||||||
return self.make_request(
|
return self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
"/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
|
"/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms/%s"
|
||||||
% other_user,
|
% other_user,
|
||||||
access_token=token,
|
access_token=token,
|
||||||
)
|
)
|
||||||
|
@ -56,14 +56,14 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
|
||||||
A room should show up in the shared list of rooms between two users
|
A room should show up in the shared list of rooms between two users
|
||||||
if it is public.
|
if it is public.
|
||||||
"""
|
"""
|
||||||
self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True)
|
self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=True)
|
||||||
|
|
||||||
def test_shared_room_list_private(self) -> None:
|
def test_shared_room_list_private(self) -> None:
|
||||||
"""
|
"""
|
||||||
A room should show up in the shared list of rooms between two users
|
A room should show up in the shared list of rooms between two users
|
||||||
if it is private.
|
if it is private.
|
||||||
"""
|
"""
|
||||||
self._check_shared_rooms_with(
|
self._check_mutual_rooms_with(
|
||||||
room_one_is_public=False, room_two_is_public=False
|
room_one_is_public=False, room_two_is_public=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -72,9 +72,9 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
|
||||||
The shared room list between two users should contain both public and private
|
The shared room list between two users should contain both public and private
|
||||||
rooms.
|
rooms.
|
||||||
"""
|
"""
|
||||||
self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False)
|
self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=False)
|
||||||
|
|
||||||
def _check_shared_rooms_with(
|
def _check_mutual_rooms_with(
|
||||||
self, room_one_is_public: bool, room_two_is_public: bool
|
self, room_one_is_public: bool, room_two_is_public: bool
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Checks that shared public or private rooms between two users appear in
|
"""Checks that shared public or private rooms between two users appear in
|
||||||
|
@ -94,7 +94,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Check shared rooms from user1's perspective.
|
# Check shared rooms from user1's perspective.
|
||||||
# We should see the one room in common
|
# We should see the one room in common
|
||||||
channel = self._get_shared_rooms(u1_token, u2)
|
channel = self._get_mutual_rooms(u1_token, u2)
|
||||||
self.assertEqual(200, channel.code, channel.result)
|
self.assertEqual(200, channel.code, channel.result)
|
||||||
self.assertEqual(len(channel.json_body["joined"]), 1)
|
self.assertEqual(len(channel.json_body["joined"]), 1)
|
||||||
self.assertEqual(channel.json_body["joined"][0], room_id_one)
|
self.assertEqual(channel.json_body["joined"][0], room_id_one)
|
||||||
|
@ -107,7 +107,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
|
||||||
self.helper.join(room_id_two, user=u2, tok=u2_token)
|
self.helper.join(room_id_two, user=u2, tok=u2_token)
|
||||||
|
|
||||||
# Check shared rooms again. We should now see both rooms.
|
# Check shared rooms again. We should now see both rooms.
|
||||||
channel = self._get_shared_rooms(u1_token, u2)
|
channel = self._get_mutual_rooms(u1_token, u2)
|
||||||
self.assertEqual(200, channel.code, channel.result)
|
self.assertEqual(200, channel.code, channel.result)
|
||||||
self.assertEqual(len(channel.json_body["joined"]), 2)
|
self.assertEqual(len(channel.json_body["joined"]), 2)
|
||||||
for room_id_id in channel.json_body["joined"]:
|
for room_id_id in channel.json_body["joined"]:
|
||||||
|
@ -128,7 +128,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
|
||||||
self.helper.join(room, user=u2, tok=u2_token)
|
self.helper.join(room, user=u2, tok=u2_token)
|
||||||
|
|
||||||
# Assert user directory is not empty
|
# Assert user directory is not empty
|
||||||
channel = self._get_shared_rooms(u1_token, u2)
|
channel = self._get_mutual_rooms(u1_token, u2)
|
||||||
self.assertEqual(200, channel.code, channel.result)
|
self.assertEqual(200, channel.code, channel.result)
|
||||||
self.assertEqual(len(channel.json_body["joined"]), 1)
|
self.assertEqual(len(channel.json_body["joined"]), 1)
|
||||||
self.assertEqual(channel.json_body["joined"][0], room)
|
self.assertEqual(channel.json_body["joined"][0], room)
|
||||||
|
@ -136,11 +136,11 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
|
||||||
self.helper.leave(room, user=u1, tok=u1_token)
|
self.helper.leave(room, user=u1, tok=u1_token)
|
||||||
|
|
||||||
# Check user1's view of shared rooms with user2
|
# Check user1's view of shared rooms with user2
|
||||||
channel = self._get_shared_rooms(u1_token, u2)
|
channel = self._get_mutual_rooms(u1_token, u2)
|
||||||
self.assertEqual(200, channel.code, channel.result)
|
self.assertEqual(200, channel.code, channel.result)
|
||||||
self.assertEqual(len(channel.json_body["joined"]), 0)
|
self.assertEqual(len(channel.json_body["joined"]), 0)
|
||||||
|
|
||||||
# Check user2's view of shared rooms with user1
|
# Check user2's view of shared rooms with user1
|
||||||
channel = self._get_shared_rooms(u2_token, u1)
|
channel = self._get_mutual_rooms(u2_token, u1)
|
||||||
self.assertEqual(200, channel.code, channel.result)
|
self.assertEqual(200, channel.code, channel.result)
|
||||||
self.assertEqual(len(channel.json_body["joined"]), 0)
|
self.assertEqual(len(channel.json_body["joined"]), 0)
|
File diff suppressed because it is too large
Load diff
|
@ -16,7 +16,6 @@ from synapse.rest.media.v1.preview_html import (
|
||||||
_get_html_media_encodings,
|
_get_html_media_encodings,
|
||||||
decode_body,
|
decode_body,
|
||||||
parse_html_to_open_graph,
|
parse_html_to_open_graph,
|
||||||
rebase_url,
|
|
||||||
summarize_paragraphs,
|
summarize_paragraphs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -161,7 +160,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tree = decode_body(html, "http://example.com/test.html")
|
tree = decode_body(html, "http://example.com/test.html")
|
||||||
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
|
og = parse_html_to_open_graph(tree)
|
||||||
|
|
||||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||||
|
|
||||||
|
@ -177,7 +176,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tree = decode_body(html, "http://example.com/test.html")
|
tree = decode_body(html, "http://example.com/test.html")
|
||||||
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
|
og = parse_html_to_open_graph(tree)
|
||||||
|
|
||||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||||
|
|
||||||
|
@ -196,7 +195,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tree = decode_body(html, "http://example.com/test.html")
|
tree = decode_body(html, "http://example.com/test.html")
|
||||||
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
|
og = parse_html_to_open_graph(tree)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
og,
|
og,
|
||||||
|
@ -218,7 +217,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tree = decode_body(html, "http://example.com/test.html")
|
tree = decode_body(html, "http://example.com/test.html")
|
||||||
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
|
og = parse_html_to_open_graph(tree)
|
||||||
|
|
||||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||||
|
|
||||||
|
@ -232,7 +231,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tree = decode_body(html, "http://example.com/test.html")
|
tree = decode_body(html, "http://example.com/test.html")
|
||||||
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
|
og = parse_html_to_open_graph(tree)
|
||||||
|
|
||||||
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
|
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
|
||||||
|
|
||||||
|
@ -247,7 +246,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tree = decode_body(html, "http://example.com/test.html")
|
tree = decode_body(html, "http://example.com/test.html")
|
||||||
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
|
og = parse_html_to_open_graph(tree)
|
||||||
|
|
||||||
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
|
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
|
||||||
|
|
||||||
|
@ -262,7 +261,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tree = decode_body(html, "http://example.com/test.html")
|
tree = decode_body(html, "http://example.com/test.html")
|
||||||
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
|
og = parse_html_to_open_graph(tree)
|
||||||
|
|
||||||
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
|
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
|
||||||
|
|
||||||
|
@ -290,7 +289,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||||
<head><title>Foo</title></head><body>Some text.</body></html>
|
<head><title>Foo</title></head><body>Some text.</body></html>
|
||||||
""".strip()
|
""".strip()
|
||||||
tree = decode_body(html, "http://example.com/test.html")
|
tree = decode_body(html, "http://example.com/test.html")
|
||||||
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
|
og = parse_html_to_open_graph(tree)
|
||||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||||
|
|
||||||
def test_invalid_encoding(self) -> None:
|
def test_invalid_encoding(self) -> None:
|
||||||
|
@ -304,7 +303,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||||
</html>
|
</html>
|
||||||
"""
|
"""
|
||||||
tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
|
tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
|
||||||
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
|
og = parse_html_to_open_graph(tree)
|
||||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||||
|
|
||||||
def test_invalid_encoding2(self) -> None:
|
def test_invalid_encoding2(self) -> None:
|
||||||
|
@ -319,7 +318,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||||
</html>
|
</html>
|
||||||
"""
|
"""
|
||||||
tree = decode_body(html, "http://example.com/test.html")
|
tree = decode_body(html, "http://example.com/test.html")
|
||||||
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
|
og = parse_html_to_open_graph(tree)
|
||||||
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
|
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
|
||||||
|
|
||||||
def test_windows_1252(self) -> None:
|
def test_windows_1252(self) -> None:
|
||||||
|
@ -333,7 +332,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||||
</html>
|
</html>
|
||||||
"""
|
"""
|
||||||
tree = decode_body(html, "http://example.com/test.html")
|
tree = decode_body(html, "http://example.com/test.html")
|
||||||
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
|
og = parse_html_to_open_graph(tree)
|
||||||
self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
|
self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
|
||||||
|
|
||||||
|
|
||||||
|
@ -448,34 +447,3 @@ class MediaEncodingTestCase(unittest.TestCase):
|
||||||
'text/html; charset="invalid"',
|
'text/html; charset="invalid"',
|
||||||
)
|
)
|
||||||
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
|
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
|
||||||
|
|
||||||
|
|
||||||
class RebaseUrlTestCase(unittest.TestCase):
|
|
||||||
def test_relative(self) -> None:
|
|
||||||
"""Relative URLs should be resolved based on the context of the base URL."""
|
|
||||||
self.assertEqual(
|
|
||||||
rebase_url("subpage", "https://example.com/foo/"),
|
|
||||||
"https://example.com/foo/subpage",
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
rebase_url("sibling", "https://example.com/foo"),
|
|
||||||
"https://example.com/sibling",
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
rebase_url("/bar", "https://example.com/foo/"),
|
|
||||||
"https://example.com/bar",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_absolute(self) -> None:
|
|
||||||
"""Absolute URLs should not be modified."""
|
|
||||||
self.assertEqual(
|
|
||||||
rebase_url("https://alice.com/a/", "https://example.com/foo/"),
|
|
||||||
"https://alice.com/a/",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_data(self) -> None:
|
|
||||||
"""Data URLs should not be modified."""
|
|
||||||
self.assertEqual(
|
|
||||||
rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"),
|
|
||||||
"data:,Hello%2C%20World%21",
|
|
||||||
)
|
|
||||||
|
|
|
@ -54,13 +54,18 @@ from twisted.internet.interfaces import (
|
||||||
ITransport,
|
ITransport,
|
||||||
)
|
)
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
|
from twisted.test.proto_helpers import (
|
||||||
|
AccumulatingProtocol,
|
||||||
|
MemoryReactor,
|
||||||
|
MemoryReactorClock,
|
||||||
|
)
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
from twisted.web.resource import IResource
|
from twisted.web.resource import IResource
|
||||||
from twisted.web.server import Request, Site
|
from twisted.web.server import Request, Site
|
||||||
|
|
||||||
from synapse.config.database import DatabaseConnectionConfig
|
from synapse.config.database import DatabaseConnectionConfig
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
|
from synapse.logging.context import ContextResourceUsage
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
from synapse.storage.engines import PostgresEngine, create_engine
|
from synapse.storage.engines import PostgresEngine, create_engine
|
||||||
|
@ -88,18 +93,19 @@ class TimedOutException(Exception):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s(auto_attribs=True)
|
||||||
class FakeChannel:
|
class FakeChannel:
|
||||||
"""
|
"""
|
||||||
A fake Twisted Web Channel (the part that interfaces with the
|
A fake Twisted Web Channel (the part that interfaces with the
|
||||||
wire).
|
wire).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
site = attr.ib(type=Union[Site, "FakeSite"])
|
site: Union[Site, "FakeSite"]
|
||||||
_reactor = attr.ib()
|
_reactor: MemoryReactor
|
||||||
result = attr.ib(type=dict, default=attr.Factory(dict))
|
result: dict = attr.Factory(dict)
|
||||||
_ip = attr.ib(type=str, default="127.0.0.1")
|
_ip: str = "127.0.0.1"
|
||||||
_producer: Optional[Union[IPullProducer, IPushProducer]] = None
|
_producer: Optional[Union[IPullProducer, IPushProducer]] = None
|
||||||
|
resource_usage: Optional[ContextResourceUsage] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def json_body(self):
|
def json_body(self):
|
||||||
|
@ -168,6 +174,8 @@ class FakeChannel:
|
||||||
|
|
||||||
def requestDone(self, _self):
|
def requestDone(self, _self):
|
||||||
self.result["done"] = True
|
self.result["done"] = True
|
||||||
|
if isinstance(_self, SynapseRequest):
|
||||||
|
self.resource_usage = _self.logcontext.get_resource_usage()
|
||||||
|
|
||||||
def getPeer(self):
|
def getPeer(self):
|
||||||
# We give an address so that getClientIP returns a non null entry,
|
# We give an address so that getClientIP returns a non null entry,
|
||||||
|
|
|
@ -47,9 +47,18 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
|
||||||
expected_ignorer_user_ids,
|
expected_ignorer_user_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def assert_ignored(
|
||||||
|
self, ignorer_user_id: str, expected_ignored_user_ids: Set[str]
|
||||||
|
) -> None:
|
||||||
|
self.assertEqual(
|
||||||
|
self.get_success(self.store.ignored_users(ignorer_user_id)),
|
||||||
|
expected_ignored_user_ids,
|
||||||
|
)
|
||||||
|
|
||||||
def test_ignoring_users(self):
|
def test_ignoring_users(self):
|
||||||
"""Basic adding/removing of users from the ignore list."""
|
"""Basic adding/removing of users from the ignore list."""
|
||||||
self._update_ignore_list("@other:test", "@another:remote")
|
self._update_ignore_list("@other:test", "@another:remote")
|
||||||
|
self.assert_ignored(self.user, {"@other:test", "@another:remote"})
|
||||||
|
|
||||||
# Check a user which no one ignores.
|
# Check a user which no one ignores.
|
||||||
self.assert_ignorers("@user:test", set())
|
self.assert_ignorers("@user:test", set())
|
||||||
|
@ -62,6 +71,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Add one user, remove one user, and leave one user.
|
# Add one user, remove one user, and leave one user.
|
||||||
self._update_ignore_list("@foo:test", "@another:remote")
|
self._update_ignore_list("@foo:test", "@another:remote")
|
||||||
|
self.assert_ignored(self.user, {"@foo:test", "@another:remote"})
|
||||||
|
|
||||||
# Check the removed user.
|
# Check the removed user.
|
||||||
self.assert_ignorers("@other:test", set())
|
self.assert_ignorers("@other:test", set())
|
||||||
|
@ -76,20 +86,24 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
|
||||||
"""Ensure that caching works properly between different users."""
|
"""Ensure that caching works properly between different users."""
|
||||||
# The first user ignores a user.
|
# The first user ignores a user.
|
||||||
self._update_ignore_list("@other:test")
|
self._update_ignore_list("@other:test")
|
||||||
|
self.assert_ignored(self.user, {"@other:test"})
|
||||||
self.assert_ignorers("@other:test", {self.user})
|
self.assert_ignorers("@other:test", {self.user})
|
||||||
|
|
||||||
# The second user ignores them.
|
# The second user ignores them.
|
||||||
self._update_ignore_list("@other:test", ignorer_user_id="@second:test")
|
self._update_ignore_list("@other:test", ignorer_user_id="@second:test")
|
||||||
|
self.assert_ignored("@second:test", {"@other:test"})
|
||||||
self.assert_ignorers("@other:test", {self.user, "@second:test"})
|
self.assert_ignorers("@other:test", {self.user, "@second:test"})
|
||||||
|
|
||||||
# The first user un-ignores them.
|
# The first user un-ignores them.
|
||||||
self._update_ignore_list()
|
self._update_ignore_list()
|
||||||
|
self.assert_ignored(self.user, set())
|
||||||
self.assert_ignorers("@other:test", {"@second:test"})
|
self.assert_ignorers("@other:test", {"@second:test"})
|
||||||
|
|
||||||
def test_invalid_data(self):
|
def test_invalid_data(self):
|
||||||
"""Invalid data ends up clearing out the ignored users list."""
|
"""Invalid data ends up clearing out the ignored users list."""
|
||||||
# Add some data and ensure it is there.
|
# Add some data and ensure it is there.
|
||||||
self._update_ignore_list("@other:test")
|
self._update_ignore_list("@other:test")
|
||||||
|
self.assert_ignored(self.user, {"@other:test"})
|
||||||
self.assert_ignorers("@other:test", {self.user})
|
self.assert_ignorers("@other:test", {self.user})
|
||||||
|
|
||||||
# No ignored_users key.
|
# No ignored_users key.
|
||||||
|
@ -102,10 +116,12 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# No one ignores the user now.
|
# No one ignores the user now.
|
||||||
|
self.assert_ignored(self.user, set())
|
||||||
self.assert_ignorers("@other:test", set())
|
self.assert_ignorers("@other:test", set())
|
||||||
|
|
||||||
# Add some data and ensure it is there.
|
# Add some data and ensure it is there.
|
||||||
self._update_ignore_list("@other:test")
|
self._update_ignore_list("@other:test")
|
||||||
|
self.assert_ignored(self.user, {"@other:test"})
|
||||||
self.assert_ignorers("@other:test", {self.user})
|
self.assert_ignorers("@other:test", {self.user})
|
||||||
|
|
||||||
# Invalid data.
|
# Invalid data.
|
||||||
|
@ -118,4 +134,5 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# No one ignores the user now.
|
# No one ignores the user now.
|
||||||
|
self.assert_ignored(self.user, set())
|
||||||
self.assert_ignorers("@other:test", set())
|
self.assert_ignorers("@other:test", set())
|
||||||
|
|
|
@ -17,8 +17,12 @@ from unittest.mock import Mock
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred, ensureDeferred
|
from twisted.internet.defer import Deferred, ensureDeferred
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.background_updates import BackgroundUpdater
|
from synapse.storage.background_updates import BackgroundUpdater
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import make_awaitable, simple_async_mock
|
from tests.test_utils import make_awaitable, simple_async_mock
|
||||||
|
@ -26,7 +30,7 @@ from tests.unittest import override_config
|
||||||
|
|
||||||
|
|
||||||
class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
|
self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
|
||||||
# the base test class should have run the real bg updates for us
|
# the base test class should have run the real bg updates for us
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
@ -39,7 +43,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
|
|
||||||
async def update(self, progress, count):
|
async def update(self, progress: JsonDict, count: int) -> int:
|
||||||
duration_ms = 10
|
duration_ms = 10
|
||||||
await self.clock.sleep((count * duration_ms) / 1000)
|
await self.clock.sleep((count * duration_ms) / 1000)
|
||||||
progress = {"my_key": progress["my_key"] + 1}
|
progress = {"my_key": progress["my_key"] + 1}
|
||||||
|
@ -51,7 +55,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
return count
|
return count
|
||||||
|
|
||||||
def test_do_background_update(self):
|
def test_do_background_update(self) -> None:
|
||||||
# the time we claim it takes to update one item when running the update
|
# the time we claim it takes to update one item when running the update
|
||||||
duration_ms = 10
|
duration_ms = 10
|
||||||
|
|
||||||
|
@ -80,7 +84,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# second step: complete the update
|
# second step: complete the update
|
||||||
# we should now get run with a much bigger number of items to update
|
# we should now get run with a much bigger number of items to update
|
||||||
async def update(progress, count):
|
async def update(progress: JsonDict, count: int) -> int:
|
||||||
self.assertEqual(progress, {"my_key": 2})
|
self.assertEqual(progress, {"my_key": 2})
|
||||||
self.assertAlmostEqual(
|
self.assertAlmostEqual(
|
||||||
count,
|
count,
|
||||||
|
@ -110,7 +114,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
def test_background_update_default_batch_set_by_config(self):
|
def test_background_update_default_batch_set_by_config(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test that the background update is run with the default_batch_size set by the config
|
Test that the background update is run with the default_batch_size set by the config
|
||||||
"""
|
"""
|
||||||
|
@ -133,7 +137,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
# on the first call, we should get run with the default background update size specified in the config
|
# on the first call, we should get run with the default background update size specified in the config
|
||||||
self.update_handler.assert_called_once_with({"my_key": 1}, 20)
|
self.update_handler.assert_called_once_with({"my_key": 1}, 20)
|
||||||
|
|
||||||
def test_background_update_default_sleep_behavior(self):
|
def test_background_update_default_sleep_behavior(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test default background update behavior, which is to sleep
|
Test default background update behavior, which is to sleep
|
||||||
"""
|
"""
|
||||||
|
@ -147,7 +151,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.update_handler.side_effect = self.update
|
self.update_handler.side_effect = self.update
|
||||||
self.update_handler.reset_mock()
|
self.update_handler.reset_mock()
|
||||||
self.updates.start_doing_background_updates(),
|
self.updates.start_doing_background_updates()
|
||||||
|
|
||||||
# 2: advance the reactor less than the default sleep duration (1000ms)
|
# 2: advance the reactor less than the default sleep duration (1000ms)
|
||||||
self.reactor.pump([0.5])
|
self.reactor.pump([0.5])
|
||||||
|
@ -167,7 +171,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
def test_background_update_sleep_set_in_config(self):
|
def test_background_update_sleep_set_in_config(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test that changing the sleep time in the config changes how long it sleeps
|
Test that changing the sleep time in the config changes how long it sleeps
|
||||||
"""
|
"""
|
||||||
|
@ -181,7 +185,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.update_handler.side_effect = self.update
|
self.update_handler.side_effect = self.update
|
||||||
self.update_handler.reset_mock()
|
self.update_handler.reset_mock()
|
||||||
self.updates.start_doing_background_updates(),
|
self.updates.start_doing_background_updates()
|
||||||
|
|
||||||
# 2: advance the reactor less than the configured sleep duration (500ms)
|
# 2: advance the reactor less than the configured sleep duration (500ms)
|
||||||
self.reactor.pump([0.45])
|
self.reactor.pump([0.45])
|
||||||
|
@ -201,7 +205,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
def test_disabling_background_update_sleep(self):
|
def test_disabling_background_update_sleep(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test that disabling sleep in the config results in bg update not sleeping
|
Test that disabling sleep in the config results in bg update not sleeping
|
||||||
"""
|
"""
|
||||||
|
@ -215,7 +219,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.update_handler.side_effect = self.update
|
self.update_handler.side_effect = self.update
|
||||||
self.update_handler.reset_mock()
|
self.update_handler.reset_mock()
|
||||||
self.updates.start_doing_background_updates(),
|
self.updates.start_doing_background_updates()
|
||||||
|
|
||||||
# 2: advance the reactor very little
|
# 2: advance the reactor very little
|
||||||
self.reactor.pump([0.025])
|
self.reactor.pump([0.025])
|
||||||
|
@ -230,7 +234,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
def test_background_update_duration_set_in_config(self):
|
def test_background_update_duration_set_in_config(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test that the desired duration set in the config is used in determining batch size
|
Test that the desired duration set in the config is used in determining batch size
|
||||||
"""
|
"""
|
||||||
|
@ -254,7 +258,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# the first update was run with the default batch size, this should be run with 500ms as the
|
# the first update was run with the default batch size, this should be run with 500ms as the
|
||||||
# desired duration
|
# desired duration
|
||||||
async def update(progress, count):
|
async def update(progress: JsonDict, count: int) -> int:
|
||||||
self.assertEqual(progress, {"my_key": 2})
|
self.assertEqual(progress, {"my_key": 2})
|
||||||
self.assertAlmostEqual(
|
self.assertAlmostEqual(
|
||||||
count,
|
count,
|
||||||
|
@ -275,7 +279,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
def test_background_update_min_batch_set_in_config(self):
|
def test_background_update_min_batch_set_in_config(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test that the minimum batch size set in the config is used
|
Test that the minimum batch size set in the config is used
|
||||||
"""
|
"""
|
||||||
|
@ -290,7 +294,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run the update with the long-running update item
|
# Run the update with the long-running update item
|
||||||
async def update(progress, count):
|
async def update_long(progress: JsonDict, count: int) -> int:
|
||||||
await self.clock.sleep((count * duration_ms) / 1000)
|
await self.clock.sleep((count * duration_ms) / 1000)
|
||||||
progress = {"my_key": progress["my_key"] + 1}
|
progress = {"my_key": progress["my_key"] + 1}
|
||||||
await self.store.db_pool.runInteraction(
|
await self.store.db_pool.runInteraction(
|
||||||
|
@ -301,7 +305,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
return count
|
return count
|
||||||
|
|
||||||
self.update_handler.side_effect = update
|
self.update_handler.side_effect = update_long
|
||||||
self.update_handler.reset_mock()
|
self.update_handler.reset_mock()
|
||||||
res = self.get_success(
|
res = self.get_success(
|
||||||
self.updates.do_next_background_update(False),
|
self.updates.do_next_background_update(False),
|
||||||
|
@ -311,25 +315,25 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# the first update was run with the default batch size, this should be run with minimum batch size
|
# the first update was run with the default batch size, this should be run with minimum batch size
|
||||||
# as the first items took a very long time
|
# as the first items took a very long time
|
||||||
async def update(progress, count):
|
async def update_short(progress: JsonDict, count: int) -> int:
|
||||||
self.assertEqual(progress, {"my_key": 2})
|
self.assertEqual(progress, {"my_key": 2})
|
||||||
self.assertEqual(count, 5)
|
self.assertEqual(count, 5)
|
||||||
await self.updates._end_background_update("test_update")
|
await self.updates._end_background_update("test_update")
|
||||||
return count
|
return count
|
||||||
|
|
||||||
self.update_handler.side_effect = update
|
self.update_handler.side_effect = update_short
|
||||||
self.get_success(self.updates.do_next_background_update(False))
|
self.get_success(self.updates.do_next_background_update(False))
|
||||||
|
|
||||||
|
|
||||||
class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
|
class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
|
self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
|
||||||
# the base test class should have run the real bg updates for us
|
# the base test class should have run the real bg updates for us
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
self.get_success(self.updates.has_completed_background_updates())
|
self.get_success(self.updates.has_completed_background_updates())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.update_deferred = Deferred()
|
self.update_deferred: Deferred[int] = Deferred()
|
||||||
self.update_handler = Mock(return_value=self.update_deferred)
|
self.update_handler = Mock(return_value=self.update_deferred)
|
||||||
self.updates.register_background_update_handler(
|
self.updates.register_background_update_handler(
|
||||||
"test_update", self.update_handler
|
"test_update", self.update_handler
|
||||||
|
@ -358,7 +362,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_controller(self):
|
def test_controller(self) -> None:
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
self.get_success(
|
self.get_success(
|
||||||
store.db_pool.simple_insert(
|
store.db_pool.simple_insert(
|
||||||
|
@ -368,7 +372,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set the return value for the context manager.
|
# Set the return value for the context manager.
|
||||||
enter_defer = Deferred()
|
enter_defer: Deferred[int] = Deferred()
|
||||||
self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer)
|
self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer)
|
||||||
|
|
||||||
# Start the background update.
|
# Start the background update.
|
||||||
|
|
|
@ -12,7 +12,20 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.storage.database import make_tuple_comparison_clause
|
from typing import Callable, Tuple
|
||||||
|
from unittest.mock import Mock, call
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
from twisted.internet.defer import CancelledError, Deferred
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.database import (
|
||||||
|
DatabasePool,
|
||||||
|
LoggingTransaction,
|
||||||
|
make_tuple_comparison_clause,
|
||||||
|
)
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
@ -22,3 +35,150 @@ class TupleComparisonClauseTestCase(unittest.TestCase):
|
||||||
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
|
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
|
||||||
self.assertEqual(clause, "(a,b) > (?,?)")
|
self.assertEqual(clause, "(a,b) > (?,?)")
|
||||||
self.assertEqual(args, [1, 2])
|
self.assertEqual(args, [1, 2])
|
||||||
|
|
||||||
|
|
||||||
|
class CallbacksTestCase(unittest.HomeserverTestCase):
|
||||||
|
"""Tests for transaction callbacks."""
|
||||||
|
|
||||||
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
self.store = hs.get_datastores().main
|
||||||
|
self.db_pool: DatabasePool = self.store.db_pool
|
||||||
|
|
||||||
|
def _run_interaction(
|
||||||
|
self, func: Callable[[LoggingTransaction], object]
|
||||||
|
) -> Tuple[Mock, Mock]:
|
||||||
|
"""Run the given function in a database transaction, with callbacks registered.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to be run in a transaction. The transaction will be
|
||||||
|
retried if `func` raises an `OperationalError`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Two mocks, which were registered as an `after_callback` and an
|
||||||
|
`exception_callback` respectively, on every transaction attempt.
|
||||||
|
"""
|
||||||
|
after_callback = Mock()
|
||||||
|
exception_callback = Mock()
|
||||||
|
|
||||||
|
def _test_txn(txn: LoggingTransaction) -> None:
|
||||||
|
txn.call_after(after_callback, 123, 456, extra=789)
|
||||||
|
txn.call_on_exception(exception_callback, 987, 654, extra=321)
|
||||||
|
func(txn)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.get_success_or_raise(
|
||||||
|
self.db_pool.runInteraction("test_transaction", _test_txn)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return after_callback, exception_callback
|
||||||
|
|
||||||
|
def test_after_callback(self) -> None:
|
||||||
|
"""Test that the after callback is called when a transaction succeeds."""
|
||||||
|
after_callback, exception_callback = self._run_interaction(lambda txn: None)
|
||||||
|
|
||||||
|
after_callback.assert_called_once_with(123, 456, extra=789)
|
||||||
|
exception_callback.assert_not_called()
|
||||||
|
|
||||||
|
def test_exception_callback(self) -> None:
|
||||||
|
"""Test that the exception callback is called when a transaction fails."""
|
||||||
|
_test_txn = Mock(side_effect=ZeroDivisionError)
|
||||||
|
after_callback, exception_callback = self._run_interaction(_test_txn)
|
||||||
|
|
||||||
|
after_callback.assert_not_called()
|
||||||
|
exception_callback.assert_called_once_with(987, 654, extra=321)
|
||||||
|
|
||||||
|
def test_failed_retry(self) -> None:
|
||||||
|
"""Test that the exception callback is called for every failed attempt."""
|
||||||
|
# Always raise an `OperationalError`.
|
||||||
|
_test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError)
|
||||||
|
after_callback, exception_callback = self._run_interaction(_test_txn)
|
||||||
|
|
||||||
|
after_callback.assert_not_called()
|
||||||
|
exception_callback.assert_has_calls(
|
||||||
|
[
|
||||||
|
call(987, 654, extra=321),
|
||||||
|
call(987, 654, extra=321),
|
||||||
|
call(987, 654, extra=321),
|
||||||
|
call(987, 654, extra=321),
|
||||||
|
call(987, 654, extra=321),
|
||||||
|
call(987, 654, extra=321),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.assertEqual(exception_callback.call_count, 6) # no additional calls
|
||||||
|
|
||||||
|
def test_successful_retry(self) -> None:
|
||||||
|
"""Test callbacks for a failed transaction followed by a successful attempt."""
|
||||||
|
# Raise an `OperationalError` on the first attempt only.
|
||||||
|
_test_txn = Mock(
|
||||||
|
side_effect=[self.db_pool.engine.module.OperationalError, None]
|
||||||
|
)
|
||||||
|
after_callback, exception_callback = self._run_interaction(_test_txn)
|
||||||
|
|
||||||
|
# Calling both `after_callback`s when the first attempt failed is rather
|
||||||
|
# surprising (#12184). Let's document the behaviour in a test.
|
||||||
|
after_callback.assert_has_calls(
|
||||||
|
[
|
||||||
|
call(123, 456, extra=789),
|
||||||
|
call(123, 456, extra=789),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.assertEqual(after_callback.call_count, 2) # no additional calls
|
||||||
|
exception_callback.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class CancellationTestCase(unittest.HomeserverTestCase):
|
||||||
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
self.store = hs.get_datastores().main
|
||||||
|
self.db_pool: DatabasePool = self.store.db_pool
|
||||||
|
|
||||||
|
def test_after_callback(self) -> None:
|
||||||
|
"""Test that the after callback is called when a transaction succeeds."""
|
||||||
|
d: "Deferred[None]"
|
||||||
|
after_callback = Mock()
|
||||||
|
exception_callback = Mock()
|
||||||
|
|
||||||
|
def _test_txn(txn: LoggingTransaction) -> None:
|
||||||
|
txn.call_after(after_callback, 123, 456, extra=789)
|
||||||
|
txn.call_on_exception(exception_callback, 987, 654, extra=321)
|
||||||
|
d.cancel()
|
||||||
|
|
||||||
|
d = defer.ensureDeferred(
|
||||||
|
self.db_pool.runInteraction("test_transaction", _test_txn)
|
||||||
|
)
|
||||||
|
self.get_failure(d, CancelledError)
|
||||||
|
|
||||||
|
after_callback.assert_called_once_with(123, 456, extra=789)
|
||||||
|
exception_callback.assert_not_called()
|
||||||
|
|
||||||
|
def test_exception_callback(self) -> None:
|
||||||
|
"""Test that the exception callback is called when a transaction fails."""
|
||||||
|
d: "Deferred[None]"
|
||||||
|
after_callback = Mock()
|
||||||
|
exception_callback = Mock()
|
||||||
|
|
||||||
|
def _test_txn(txn: LoggingTransaction) -> None:
|
||||||
|
txn.call_after(after_callback, 123, 456, extra=789)
|
||||||
|
txn.call_on_exception(exception_callback, 987, 654, extra=321)
|
||||||
|
d.cancel()
|
||||||
|
# Simulate a retryable failure on every attempt.
|
||||||
|
raise self.db_pool.engine.module.OperationalError()
|
||||||
|
|
||||||
|
d = defer.ensureDeferred(
|
||||||
|
self.db_pool.runInteraction("test_transaction", _test_txn)
|
||||||
|
)
|
||||||
|
self.get_failure(d, CancelledError)
|
||||||
|
|
||||||
|
after_callback.assert_not_called()
|
||||||
|
exception_callback.assert_has_calls(
|
||||||
|
[
|
||||||
|
call(987, 654, extra=321),
|
||||||
|
call(987, 654, extra=321),
|
||||||
|
call(987, 654, extra=321),
|
||||||
|
call(987, 654, extra=321),
|
||||||
|
call(987, 654, extra=321),
|
||||||
|
call(987, 654, extra=321),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.assertEqual(exception_callback.call_count, 6) # no additional calls
|
||||||
|
|
|
@ -13,9 +13,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from synapse.storage.database import DatabasePool
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||||
from synapse.storage.engines import IncorrectDatabaseSetup
|
from synapse.storage.engines import IncorrectDatabaseSetup
|
||||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
from tests.utils import USE_POSTGRES_FOR_TESTS
|
from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||||
|
@ -25,13 +29,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
if not USE_POSTGRES_FOR_TESTS:
|
if not USE_POSTGRES_FOR_TESTS:
|
||||||
skip = "Requires Postgres"
|
skip = "Requires Postgres"
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.db_pool: DatabasePool = self.store.db_pool
|
self.db_pool: DatabasePool = self.store.db_pool
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
||||||
|
|
||||||
def _setup_db(self, txn):
|
def _setup_db(self, txn: LoggingTransaction) -> None:
|
||||||
txn.execute("CREATE SEQUENCE foobar_seq")
|
txn.execute("CREATE SEQUENCE foobar_seq")
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
|
@ -59,12 +63,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
|
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
|
||||||
|
|
||||||
def _insert_rows(self, instance_name: str, number: int):
|
def _insert_rows(self, instance_name: str, number: int) -> None:
|
||||||
"""Insert N rows as the given instance, inserting with stream IDs pulled
|
"""Insert N rows as the given instance, inserting with stream IDs pulled
|
||||||
from the postgres sequence.
|
from the postgres sequence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _insert(txn):
|
def _insert(txn: LoggingTransaction) -> None:
|
||||||
for _ in range(number):
|
for _ in range(number):
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
|
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
|
||||||
|
@ -80,12 +84,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
|
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
|
||||||
|
|
||||||
def _insert_row_with_id(self, instance_name: str, stream_id: int):
|
def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
|
||||||
"""Insert one row as the given instance with given stream_id, updating
|
"""Insert one row as the given instance with given stream_id, updating
|
||||||
the postgres sequence position to match.
|
the postgres sequence position to match.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _insert(txn):
|
def _insert(txn: LoggingTransaction) -> None:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"INSERT INTO foobar VALUES (?, ?)",
|
"INSERT INTO foobar VALUES (?, ?)",
|
||||||
(
|
(
|
||||||
|
@ -104,7 +108,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
|
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
|
||||||
|
|
||||||
def test_empty(self):
|
def test_empty(self) -> None:
|
||||||
"""Test an ID generator against an empty database gives sensible
|
"""Test an ID generator against an empty database gives sensible
|
||||||
current positions.
|
current positions.
|
||||||
"""
|
"""
|
||||||
|
@ -114,7 +118,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
# The table is empty so we expect an empty map for positions
|
# The table is empty so we expect an empty map for positions
|
||||||
self.assertEqual(id_gen.get_positions(), {})
|
self.assertEqual(id_gen.get_positions(), {})
|
||||||
|
|
||||||
def test_single_instance(self):
|
def test_single_instance(self) -> None:
|
||||||
"""Test that reads and writes from a single process are handled
|
"""Test that reads and writes from a single process are handled
|
||||||
correctly.
|
correctly.
|
||||||
"""
|
"""
|
||||||
|
@ -130,7 +134,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
# Try allocating a new ID gen and check that we only see position
|
# Try allocating a new ID gen and check that we only see position
|
||||||
# advanced after we leave the context manager.
|
# advanced after we leave the context manager.
|
||||||
|
|
||||||
async def _get_next_async():
|
async def _get_next_async() -> None:
|
||||||
async with id_gen.get_next() as stream_id:
|
async with id_gen.get_next() as stream_id:
|
||||||
self.assertEqual(stream_id, 8)
|
self.assertEqual(stream_id, 8)
|
||||||
|
|
||||||
|
@ -142,7 +146,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
|
||||||
|
|
||||||
def test_out_of_order_finish(self):
|
def test_out_of_order_finish(self) -> None:
|
||||||
"""Test that IDs persisted out of order are correctly handled"""
|
"""Test that IDs persisted out of order are correctly handled"""
|
||||||
|
|
||||||
# Prefill table with 7 rows written by 'master'
|
# Prefill table with 7 rows written by 'master'
|
||||||
|
@ -191,7 +195,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": 11})
|
self.assertEqual(id_gen.get_positions(), {"master": 11})
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
|
||||||
|
|
||||||
def test_multi_instance(self):
|
def test_multi_instance(self) -> None:
|
||||||
"""Test that reads and writes from multiple processes are handled
|
"""Test that reads and writes from multiple processes are handled
|
||||||
correctly.
|
correctly.
|
||||||
"""
|
"""
|
||||||
|
@ -215,7 +219,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
# Try allocating a new ID gen and check that we only see position
|
# Try allocating a new ID gen and check that we only see position
|
||||||
# advanced after we leave the context manager.
|
# advanced after we leave the context manager.
|
||||||
|
|
||||||
async def _get_next_async():
|
async def _get_next_async() -> None:
|
||||||
async with first_id_gen.get_next() as stream_id:
|
async with first_id_gen.get_next() as stream_id:
|
||||||
self.assertEqual(stream_id, 8)
|
self.assertEqual(stream_id, 8)
|
||||||
|
|
||||||
|
@ -233,7 +237,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
# ... but calling `get_next` on the second instance should give a unique
|
# ... but calling `get_next` on the second instance should give a unique
|
||||||
# stream ID
|
# stream ID
|
||||||
|
|
||||||
async def _get_next_async():
|
async def _get_next_async2() -> None:
|
||||||
async with second_id_gen.get_next() as stream_id:
|
async with second_id_gen.get_next() as stream_id:
|
||||||
self.assertEqual(stream_id, 9)
|
self.assertEqual(stream_id, 9)
|
||||||
|
|
||||||
|
@ -241,7 +245,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
second_id_gen.get_positions(), {"first": 3, "second": 7}
|
second_id_gen.get_positions(), {"first": 3, "second": 7}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(_get_next_async())
|
self.get_success(_get_next_async2())
|
||||||
|
|
||||||
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
|
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
|
||||||
|
|
||||||
|
@ -249,7 +253,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
second_id_gen.advance("first", 8)
|
second_id_gen.advance("first", 8)
|
||||||
self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
|
self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
|
||||||
|
|
||||||
def test_get_next_txn(self):
|
def test_get_next_txn(self) -> None:
|
||||||
"""Test that the `get_next_txn` function works correctly."""
|
"""Test that the `get_next_txn` function works correctly."""
|
||||||
|
|
||||||
# Prefill table with 7 rows written by 'master'
|
# Prefill table with 7 rows written by 'master'
|
||||||
|
@ -263,7 +267,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
# Try allocating a new ID gen and check that we only see position
|
# Try allocating a new ID gen and check that we only see position
|
||||||
# advanced after we leave the context manager.
|
# advanced after we leave the context manager.
|
||||||
|
|
||||||
def _get_next_txn(txn):
|
def _get_next_txn(txn: LoggingTransaction) -> None:
|
||||||
stream_id = id_gen.get_next_txn(txn)
|
stream_id = id_gen.get_next_txn(txn)
|
||||||
self.assertEqual(stream_id, 8)
|
self.assertEqual(stream_id, 8)
|
||||||
|
|
||||||
|
@ -275,7 +279,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
|
||||||
|
|
||||||
def test_get_persisted_upto_position(self):
|
def test_get_persisted_upto_position(self) -> None:
|
||||||
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
||||||
positions.
|
positions.
|
||||||
"""
|
"""
|
||||||
|
@ -317,7 +321,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
id_gen.advance("second", 15)
|
id_gen.advance("second", 15)
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
|
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
|
||||||
|
|
||||||
def test_get_persisted_upto_position_get_next(self):
|
def test_get_persisted_upto_position_get_next(self) -> None:
|
||||||
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
||||||
positions when `get_next` is called.
|
positions when `get_next` is called.
|
||||||
"""
|
"""
|
||||||
|
@ -331,7 +335,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
|
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
|
||||||
|
|
||||||
async def _get_next_async():
|
async def _get_next_async() -> None:
|
||||||
async with id_gen.get_next() as stream_id:
|
async with id_gen.get_next() as stream_id:
|
||||||
self.assertEqual(stream_id, 6)
|
self.assertEqual(stream_id, 6)
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
|
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
|
||||||
|
@ -344,7 +348,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
# `persisted_upto_position` in this case, then it will be correct in the
|
# `persisted_upto_position` in this case, then it will be correct in the
|
||||||
# other cases that are tested above (since they'll hit the same code).
|
# other cases that are tested above (since they'll hit the same code).
|
||||||
|
|
||||||
def test_restart_during_out_of_order_persistence(self):
|
def test_restart_during_out_of_order_persistence(self) -> None:
|
||||||
"""Test that restarting a process while another process is writing out
|
"""Test that restarting a process while another process is writing out
|
||||||
of order updates are handled correctly.
|
of order updates are handled correctly.
|
||||||
"""
|
"""
|
||||||
|
@ -388,7 +392,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
id_gen_worker.advance("master", 9)
|
id_gen_worker.advance("master", 9)
|
||||||
self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
|
self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
|
||||||
|
|
||||||
def test_writer_config_change(self):
|
def test_writer_config_change(self) -> None:
|
||||||
"""Test that changing the writer config correctly works."""
|
"""Test that changing the writer config correctly works."""
|
||||||
|
|
||||||
self._insert_row_with_id("first", 3)
|
self._insert_row_with_id("first", 3)
|
||||||
|
@ -421,7 +425,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# Check that we get a sane next stream ID with this new config.
|
# Check that we get a sane next stream ID with this new config.
|
||||||
|
|
||||||
async def _get_next_async():
|
async def _get_next_async() -> None:
|
||||||
async with id_gen_3.get_next() as stream_id:
|
async with id_gen_3.get_next() as stream_id:
|
||||||
self.assertEqual(stream_id, 6)
|
self.assertEqual(stream_id, 6)
|
||||||
|
|
||||||
|
@ -435,7 +439,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6)
|
self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6)
|
||||||
self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
|
self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
|
||||||
|
|
||||||
def test_sequence_consistency(self):
|
def test_sequence_consistency(self) -> None:
|
||||||
"""Test that we error out if the table and sequence diverges."""
|
"""Test that we error out if the table and sequence diverges."""
|
||||||
|
|
||||||
# Prefill with some rows
|
# Prefill with some rows
|
||||||
|
@ -458,13 +462,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
if not USE_POSTGRES_FOR_TESTS:
|
if not USE_POSTGRES_FOR_TESTS:
|
||||||
skip = "Requires Postgres"
|
skip = "Requires Postgres"
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.db_pool: DatabasePool = self.store.db_pool
|
self.db_pool: DatabasePool = self.store.db_pool
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
||||||
|
|
||||||
def _setup_db(self, txn):
|
def _setup_db(self, txn: LoggingTransaction) -> None:
|
||||||
txn.execute("CREATE SEQUENCE foobar_seq")
|
txn.execute("CREATE SEQUENCE foobar_seq")
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
|
@ -493,10 +497,10 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
return self.get_success(self.db_pool.runWithConnection(_create))
|
return self.get_success(self.db_pool.runWithConnection(_create))
|
||||||
|
|
||||||
def _insert_row(self, instance_name: str, stream_id: int):
|
def _insert_row(self, instance_name: str, stream_id: int) -> None:
|
||||||
"""Insert one row as the given instance with given stream_id."""
|
"""Insert one row as the given instance with given stream_id."""
|
||||||
|
|
||||||
def _insert(txn):
|
def _insert(txn: LoggingTransaction) -> None:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"INSERT INTO foobar VALUES (?, ?)",
|
"INSERT INTO foobar VALUES (?, ?)",
|
||||||
(
|
(
|
||||||
|
@ -514,13 +518,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
|
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
|
||||||
|
|
||||||
def test_single_instance(self):
|
def test_single_instance(self) -> None:
|
||||||
"""Test that reads and writes from a single process are handled
|
"""Test that reads and writes from a single process are handled
|
||||||
correctly.
|
correctly.
|
||||||
"""
|
"""
|
||||||
id_gen = self._create_id_generator()
|
id_gen = self._create_id_generator()
|
||||||
|
|
||||||
async def _get_next_async():
|
async def _get_next_async() -> None:
|
||||||
async with id_gen.get_next() as stream_id:
|
async with id_gen.get_next() as stream_id:
|
||||||
self._insert_row("master", stream_id)
|
self._insert_row("master", stream_id)
|
||||||
|
|
||||||
|
@ -530,7 +534,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
|
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
|
||||||
|
|
||||||
async def _get_next_async2():
|
async def _get_next_async2() -> None:
|
||||||
async with id_gen.get_next_mult(3) as stream_ids:
|
async with id_gen.get_next_mult(3) as stream_ids:
|
||||||
for stream_id in stream_ids:
|
for stream_id in stream_ids:
|
||||||
self._insert_row("master", stream_id)
|
self._insert_row("master", stream_id)
|
||||||
|
@ -548,14 +552,14 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
|
self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
|
||||||
self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
|
self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
|
||||||
|
|
||||||
def test_multiple_instance(self):
|
def test_multiple_instance(self) -> None:
|
||||||
"""Tests that having multiple instances that get advanced over
|
"""Tests that having multiple instances that get advanced over
|
||||||
federation works corretly.
|
federation works corretly.
|
||||||
"""
|
"""
|
||||||
id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
|
id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
|
||||||
id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
|
id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
|
||||||
|
|
||||||
async def _get_next_async():
|
async def _get_next_async() -> None:
|
||||||
async with id_gen_1.get_next() as stream_id:
|
async with id_gen_1.get_next() as stream_id:
|
||||||
self._insert_row("first", stream_id)
|
self._insert_row("first", stream_id)
|
||||||
id_gen_2.advance("first", stream_id)
|
id_gen_2.advance("first", stream_id)
|
||||||
|
@ -567,7 +571,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
|
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
|
||||||
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
|
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
|
||||||
|
|
||||||
async def _get_next_async2():
|
async def _get_next_async2() -> None:
|
||||||
async with id_gen_2.get_next() as stream_id:
|
async with id_gen_2.get_next() as stream_id:
|
||||||
self._insert_row("second", stream_id)
|
self._insert_row("second", stream_id)
|
||||||
id_gen_1.advance("second", stream_id)
|
id_gen_1.advance("second", stream_id)
|
||||||
|
@ -584,13 +588,13 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
if not USE_POSTGRES_FOR_TESTS:
|
if not USE_POSTGRES_FOR_TESTS:
|
||||||
skip = "Requires Postgres"
|
skip = "Requires Postgres"
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.db_pool: DatabasePool = self.store.db_pool
|
self.db_pool: DatabasePool = self.store.db_pool
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
||||||
|
|
||||||
def _setup_db(self, txn):
|
def _setup_db(self, txn: LoggingTransaction) -> None:
|
||||||
txn.execute("CREATE SEQUENCE foobar_seq")
|
txn.execute("CREATE SEQUENCE foobar_seq")
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
|
@ -642,7 +646,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
from the postgres sequence.
|
from the postgres sequence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _insert(txn):
|
def _insert(txn: LoggingTransaction) -> None:
|
||||||
for _ in range(number):
|
for _ in range(number):
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
|
"INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
|
||||||
|
@ -659,7 +663,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
|
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
|
||||||
|
|
||||||
def test_load_existing_stream(self):
|
def test_load_existing_stream(self) -> None:
|
||||||
"""Test creating ID gens with multiple tables that have rows from after
|
"""Test creating ID gens with multiple tables that have rows from after
|
||||||
the position in `stream_positions` table.
|
the position in `stream_positions` table.
|
||||||
"""
|
"""
|
||||||
|
|
46
tests/storage/test_unsafe_locale.py
Normal file
46
tests/storage/test_unsafe_locale.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from synapse.storage.database import make_conn
|
||||||
|
from synapse.storage.engines._base import IncorrectDatabaseSetup
|
||||||
|
|
||||||
|
from tests.unittest import HomeserverTestCase
|
||||||
|
from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||||
|
|
||||||
|
|
||||||
|
class UnsafeLocaleTest(HomeserverTestCase):
|
||||||
|
if not USE_POSTGRES_FOR_TESTS:
|
||||||
|
skip = "Requires Postgres"
|
||||||
|
|
||||||
|
@patch("synapse.storage.engines.postgres.PostgresEngine.get_db_locale")
|
||||||
|
def test_unsafe_locale(self, mock_db_locale: MagicMock) -> None:
|
||||||
|
mock_db_locale.return_value = ("B", "B")
|
||||||
|
database = self.hs.get_datastores().databases[0]
|
||||||
|
|
||||||
|
db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
|
||||||
|
with self.assertRaises(IncorrectDatabaseSetup):
|
||||||
|
database.engine.check_database(db_conn)
|
||||||
|
with self.assertRaises(IncorrectDatabaseSetup):
|
||||||
|
database.engine.check_new_database(db_conn)
|
||||||
|
db_conn.close()
|
||||||
|
|
||||||
|
def test_safe_locale(self) -> None:
|
||||||
|
database = self.hs.get_datastores().databases[0]
|
||||||
|
|
||||||
|
db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
|
||||||
|
with db_conn.cursor() as txn:
|
||||||
|
res = database.engine.get_db_locale(txn)
|
||||||
|
self.assertEqual(res, ("C", "C"))
|
||||||
|
db_conn.close()
|
|
@ -12,7 +12,7 @@ from tests.unittest import TestCase
|
||||||
|
|
||||||
|
|
||||||
class DummyDistribution(metadata.Distribution):
|
class DummyDistribution(metadata.Distribution):
|
||||||
def __init__(self, version: str):
|
def __init__(self, version: object):
|
||||||
self._version = version
|
self._version = version
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -30,6 +30,7 @@ old = DummyDistribution("0.1.2")
|
||||||
old_release_candidate = DummyDistribution("0.1.2rc3")
|
old_release_candidate = DummyDistribution("0.1.2rc3")
|
||||||
new = DummyDistribution("1.2.3")
|
new = DummyDistribution("1.2.3")
|
||||||
new_release_candidate = DummyDistribution("1.2.3rc4")
|
new_release_candidate = DummyDistribution("1.2.3rc4")
|
||||||
|
distribution_with_no_version = DummyDistribution(None)
|
||||||
|
|
||||||
# could probably use stdlib TestCase --- no need for twisted here
|
# could probably use stdlib TestCase --- no need for twisted here
|
||||||
|
|
||||||
|
@ -67,6 +68,18 @@ class TestDependencyChecker(TestCase):
|
||||||
# should not raise
|
# should not raise
|
||||||
check_requirements()
|
check_requirements()
|
||||||
|
|
||||||
|
def test_version_reported_as_none(self) -> None:
|
||||||
|
"""Complain if importlib.metadata.version() returns None.
|
||||||
|
|
||||||
|
This shouldn't normally happen, but it was seen in the wild (#12223).
|
||||||
|
"""
|
||||||
|
with patch(
|
||||||
|
"synapse.util.check_dependencies.metadata.requires",
|
||||||
|
return_value=["dummypkg >= 1"],
|
||||||
|
):
|
||||||
|
with self.mock_installed_package(distribution_with_no_version):
|
||||||
|
self.assertRaises(DependencyException, check_requirements)
|
||||||
|
|
||||||
def test_checks_ignore_dev_dependencies(self) -> None:
|
def test_checks_ignore_dev_dependencies(self) -> None:
|
||||||
"""Bot generic and per-extra checks should ignore dev dependencies."""
|
"""Bot generic and per-extra checks should ignore dev dependencies."""
|
||||||
with patch(
|
with patch(
|
||||||
|
|
110
tests/utils.py
110
tests/utils.py
|
@ -15,13 +15,8 @@
|
||||||
|
|
||||||
import atexit
|
import atexit
|
||||||
import os
|
import os
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
from urllib import parse as urlparse
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import CodeMessageException, cs_error
|
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersions
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.config.server import DEFAULT_ROOM_VERSION
|
from synapse.config.server import DEFAULT_ROOM_VERSION
|
||||||
|
@ -187,111 +182,6 @@ def mock_getRawHeaders(headers=None):
|
||||||
return getRawHeaders
|
return getRawHeaders
|
||||||
|
|
||||||
|
|
||||||
# This is a mock /resource/ not an entire server
|
|
||||||
class MockHttpResource:
|
|
||||||
def __init__(self, prefix=""):
|
|
||||||
self.callbacks = [] # 3-tuple of method/pattern/function
|
|
||||||
self.prefix = prefix
|
|
||||||
|
|
||||||
def trigger_get(self, path):
|
|
||||||
return self.trigger(b"GET", path, None)
|
|
||||||
|
|
||||||
@patch("twisted.web.http.Request")
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def trigger(
|
|
||||||
self, http_method, path, content, mock_request, federation_auth_origin=None
|
|
||||||
):
|
|
||||||
"""Fire an HTTP event.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
http_method : The HTTP method
|
|
||||||
path : The HTTP path
|
|
||||||
content : The HTTP body
|
|
||||||
mock_request : Mocked request to pass to the event so it can get
|
|
||||||
content.
|
|
||||||
federation_auth_origin (bytes|None): domain to authenticate as, for federation
|
|
||||||
Returns:
|
|
||||||
A tuple of (code, response)
|
|
||||||
Raises:
|
|
||||||
KeyError If no event is found which will handle the path.
|
|
||||||
"""
|
|
||||||
path = self.prefix + path
|
|
||||||
|
|
||||||
# annoyingly we return a twisted http request which has chained calls
|
|
||||||
# to get at the http content, hence mock it here.
|
|
||||||
mock_content = Mock()
|
|
||||||
config = {"read.return_value": content}
|
|
||||||
mock_content.configure_mock(**config)
|
|
||||||
mock_request.content = mock_content
|
|
||||||
|
|
||||||
mock_request.method = http_method.encode("ascii")
|
|
||||||
mock_request.uri = path.encode("ascii")
|
|
||||||
|
|
||||||
mock_request.getClientIP.return_value = "-"
|
|
||||||
|
|
||||||
headers = {}
|
|
||||||
if federation_auth_origin is not None:
|
|
||||||
headers[b"Authorization"] = [
|
|
||||||
b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
|
|
||||||
]
|
|
||||||
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
|
|
||||||
|
|
||||||
# return the right path if the event requires it
|
|
||||||
mock_request.path = path
|
|
||||||
|
|
||||||
# add in query params to the right place
|
|
||||||
try:
|
|
||||||
mock_request.args = urlparse.parse_qs(path.split("?")[1])
|
|
||||||
mock_request.path = path.split("?")[0]
|
|
||||||
path = mock_request.path
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if isinstance(path, bytes):
|
|
||||||
path = path.decode("utf8")
|
|
||||||
|
|
||||||
for (method, pattern, func) in self.callbacks:
|
|
||||||
if http_method != method:
|
|
||||||
continue
|
|
||||||
|
|
||||||
matcher = pattern.match(path)
|
|
||||||
if matcher:
|
|
||||||
try:
|
|
||||||
args = [urlparse.unquote(u) for u in matcher.groups()]
|
|
||||||
|
|
||||||
(code, response) = yield defer.ensureDeferred(
|
|
||||||
func(mock_request, *args)
|
|
||||||
)
|
|
||||||
return code, response
|
|
||||||
except CodeMessageException as e:
|
|
||||||
return e.code, cs_error(e.msg, code=e.errcode)
|
|
||||||
|
|
||||||
raise KeyError("No event can handle %s" % path)
|
|
||||||
|
|
||||||
def register_paths(self, method, path_patterns, callback, servlet_name):
|
|
||||||
for path_pattern in path_patterns:
|
|
||||||
self.callbacks.append((method, path_pattern, callback))
|
|
||||||
|
|
||||||
|
|
||||||
class MockKey:
|
|
||||||
alg = "mock_alg"
|
|
||||||
version = "mock_version"
|
|
||||||
signature = b"\x9a\x87$"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def verify_key(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def sign(self, message):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def verify(self, message, sig):
|
|
||||||
assert sig == b"\x9a\x87$"
|
|
||||||
|
|
||||||
def encode(self):
|
|
||||||
return b"<fake_encoded_key>"
|
|
||||||
|
|
||||||
|
|
||||||
class MockClock:
|
class MockClock:
|
||||||
now = 1000
|
now = 1000
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue