0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-15 22:42:23 +01:00

Merge remote-tracking branch 'upstream/release-v1.56'

This commit is contained in:
Tulir Asokan 2022-03-29 13:30:30 +03:00
commit ef6a9c8a70
98 changed files with 6349 additions and 5436 deletions

View file

@ -377,7 +377,7 @@ jobs:
# Run Complement # Run Complement
- run: | - run: |
set -o pipefail set -o pipefail
go test -v -json -p 1 -tags synapse_blacklist,msc2403,msc2716,msc3030 ./tests/... 2>&1 | gotestfmt go test -v -json -tags synapse_blacklist,msc2403,msc2716,msc3030 ./tests/... 2>&1 | gotestfmt
shell: bash shell: bash
name: Run Complement Tests name: Run Complement Tests
env: env:

3640
CHANGES-pre-1.0.md Normal file

File diff suppressed because it is too large Load diff

3724
CHANGES.md

File diff suppressed because it is too large Load diff

18
debian/changelog vendored
View file

@ -1,3 +1,21 @@
matrix-synapse-py3 (1.56.0~rc1) stable; urgency=medium
* New synapse release 1.56.0~rc1.
-- Synapse Packaging team <packages@matrix.org> Tue, 29 Mar 2022 10:40:50 +0100
matrix-synapse-py3 (1.55.2) stable; urgency=medium
* New synapse release 1.55.2.
-- Synapse Packaging team <packages@matrix.org> Thu, 24 Mar 2022 19:07:11 +0000
matrix-synapse-py3 (1.55.1) stable; urgency=medium
* New synapse release 1.55.1.
-- Synapse Packaging team <packages@matrix.org> Thu, 24 Mar 2022 17:44:23 +0000
matrix-synapse-py3 (1.55.0) stable; urgency=medium matrix-synapse-py3 (1.55.0) stable; urgency=medium
* New synapse release 1.55.0. * New synapse release 1.55.0.

View file

@ -38,6 +38,7 @@ for port in 8080 8081 8082; do
printf '\n\n# Customisation made by demo/start.sh\n\n' printf '\n\n# Customisation made by demo/start.sh\n\n'
echo "public_baseurl: http://localhost:$port/" echo "public_baseurl: http://localhost:$port/"
echo 'enable_registration: true' echo 'enable_registration: true'
echo 'enable_registration_without_verification: true'
echo '' echo ''
# Warning, this heredoc depends on the interaction of tabs and spaces. # Warning, this heredoc depends on the interaction of tabs and spaces.

View file

@ -172,7 +172,7 @@ any of the subsequent implementations of this callback.
_First introduced in Synapse v1.37.0_ _First introduced in Synapse v1.37.0_
```python ```python
async def check_username_for_spam(user_profile: Dict[str, str]) -> bool async def check_username_for_spam(user_profile: synapse.module_api.UserProfile) -> bool
``` ```
Called when computing search results in the user directory. The module must return a Called when computing search results in the user directory. The module must return a
@ -182,9 +182,11 @@ search results; otherwise return `False`.
The profile is represented as a dictionary with the following keys: The profile is represented as a dictionary with the following keys:
* `user_id`: The Matrix ID for this user. * `user_id: str`. The Matrix ID for this user.
* `display_name`: The user's display name. * `display_name: Optional[str]`. The user's display name, or `None` if this user
* `avatar_url`: The `mxc://` URL to the user's avatar. has not set a display name.
* `avatar_url: Optional[str]`. The `mxc://` URL to the user's avatar, or `None`
if this user has not set an avatar.
The module is given a copy of the original dictionary, so modifying it from within the The module is given a copy of the original dictionary, so modifying it from within the
module cannot modify a user's profile when included in user directory search results. module cannot modify a user's profile when included in user directory search results.

View file

@ -225,6 +225,8 @@ oidc_providers:
3. Create an application for synapse in Authentik and link it to the provider. 3. Create an application for synapse in Authentik and link it to the provider.
4. Note the slug of your application, Client ID and Client Secret. 4. Note the slug of your application, Client ID and Client Secret.
Note: RSA keys must be used for signing for Authentik, ECC keys do not work.
Synapse config: Synapse config:
```yaml ```yaml
oidc_providers: oidc_providers:
@ -240,7 +242,7 @@ oidc_providers:
- "email" - "email"
user_mapping_provider: user_mapping_provider:
config: config:
localpart_template: "{{ user.preferred_username }}}" localpart_template: "{{ user.preferred_username }}"
display_name_template: "{{ user.preferred_username|capitalize }}" # TO BE FILLED: If your users have names in Authentik and you want those in Synapse, this should be replaced with user.name|capitalize. display_name_template: "{{ user.preferred_username|capitalize }}" # TO BE FILLED: If your users have names in Authentik and you want those in Synapse, this should be replaced with user.name|capitalize.
``` ```

View file

@ -234,12 +234,13 @@ host all all ::1/128 ident
### Fixing incorrect `COLLATE` or `CTYPE` ### Fixing incorrect `COLLATE` or `CTYPE`
Synapse will refuse to set up a new database if it has the wrong values of Synapse will refuse to set up a new database if it has the wrong values of
`COLLATE` and `CTYPE` set, and will log warnings on existing databases. Using `COLLATE` and `CTYPE` set. Synapse will also refuse to start an existing database with incorrect values
different locales can cause issues if the locale library is updated from of `COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the
`database` section of the config, is set to true. Using different locales can cause issues if the locale library is updated from
underneath the database, or if a different version of the locale is used on any underneath the database, or if a different version of the locale is used on any
replicas. replicas.
The safest way to fix the issue is to dump the database and recreate it with If you have a databse with an unsafe locale, the safest way to fix the issue is to dump the database and recreate it with
the correct locale parameter (as shown above). It is also possible to change the the correct locale parameter (as shown above). It is also possible to change the
parameters on a live database and run a `REINDEX` on the entire database, parameters on a live database and run a `REINDEX` on the entire database,
however extreme care must be taken to avoid database corruption. however extreme care must be taken to avoid database corruption.

View file

@ -182,7 +182,7 @@ matrix.example.com {
``` ```
frontend https frontend https
bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1 bind *:443,[::]:443 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
http-request set-header X-Forwarded-Proto https if { ssl_fc } http-request set-header X-Forwarded-Proto https if { ssl_fc }
http-request set-header X-Forwarded-Proto http if !{ ssl_fc } http-request set-header X-Forwarded-Proto http if !{ ssl_fc }
http-request set-header X-Forwarded-For %[src] http-request set-header X-Forwarded-For %[src]
@ -195,7 +195,7 @@ frontend https
use_backend matrix if matrix-host matrix-path use_backend matrix if matrix-host matrix-path
frontend matrix-federation frontend matrix-federation
bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1 bind *:8448,[::]:8448 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
http-request set-header X-Forwarded-Proto https if { ssl_fc } http-request set-header X-Forwarded-Proto https if { ssl_fc }
http-request set-header X-Forwarded-Proto http if !{ ssl_fc } http-request set-header X-Forwarded-Proto http if !{ ssl_fc }
http-request set-header X-Forwarded-For %[src] http-request set-header X-Forwarded-For %[src]

View file

@ -783,6 +783,12 @@ caches:
# 'txn_limit' gives the maximum number of transactions to run per connection # 'txn_limit' gives the maximum number of transactions to run per connection
# before reconnecting. Defaults to 0, which means no limit. # before reconnecting. Defaults to 0, which means no limit.
# #
# 'allow_unsafe_locale' is an option specific to Postgres. Under the default behavior, Synapse will refuse to
# start if the postgres db is set to a non-C locale. You can override this behavior (which is *not* recommended)
# by setting 'allow_unsafe_locale' to true. Note that doing so may corrupt your database. You can find more information
# here: https://matrix-org.github.io/synapse/latest/postgres.html#fixing-incorrect-collate-or-ctype and here:
# https://wiki.postgresql.org/wiki/Locale_data_changes
#
# 'args' gives options which are passed through to the database engine, # 'args' gives options which are passed through to the database engine,
# except for options starting 'cp_', which are used to configure the Twisted # except for options starting 'cp_', which are used to configure the Twisted
# connection pool. For a reference to valid arguments, see: # connection pool. For a reference to valid arguments, see:
@ -1212,10 +1218,18 @@ oembed:
# Registration can be rate-limited using the parameters in the "Ratelimiting" # Registration can be rate-limited using the parameters in the "Ratelimiting"
# section of this file. # section of this file.
# Enable registration for new users. # Enable registration for new users. Defaults to 'false'. It is highly recommended that if you enable registration,
# you use either captcha, email, or token-based verification to verify that new users are not bots. In order to enable registration
# without any verification, you must also set `enable_registration_without_verification`, found below.
# #
#enable_registration: false #enable_registration: false
# Enable registration without email or captcha verification. Note: this option is *not* recommended,
# as registration without verification is a known vector for spam and abuse. Defaults to false. Has no effect
# unless `enable_registration` is also enabled.
#
#enable_registration_without_verification: true
# Time that a user's session remains valid for, after they log in. # Time that a user's session remains valid for, after they log in.
# #
# Note that this is not currently compatible with guest logins. # Note that this is not currently compatible with guest logins.

View file

@ -99,8 +99,21 @@ experimental_features:
groups_enabled: false groups_enabled: false
``` ```
## Change in behaviour for PostgreSQL databases with unsafe locale
Synapse now refuses to start when using PostgreSQL with non-`C` values for `COLLATE` and
`CTYPE` unless the config flag `allow_unsafe_locale`, found in the database section of
the configuration file, is set to `true`. See the [PostgreSQL documentation](https://matrix-org.github.io/synapse/latest/postgres.html#fixing-incorrect-collate-or-ctype)
for more information and instructions on how to fix a database with incorrect values.
# Upgrading to v1.55.0 # Upgrading to v1.55.0
## Open registration without verification is now disabled by default
Synapse will refuse to start if registration is enabled without email, captcha, or token-based verification unless the new config
flag `enable_registration_without_verification` is set to "true".
## `synctl` script has been moved ## `synctl` script has been moved
The `synctl` script The `synctl` script

View file

@ -185,8 +185,8 @@ worker: refer to the [stream writers](#stream-writers) section below for further
information. information.
# Sync requests # Sync requests
^/_matrix/client/(v2_alpha|r0|v3)/sync$ ^/_matrix/client/(r0|v3)/sync$
^/_matrix/client/(api/v1|v2_alpha|r0|v3)/events$ ^/_matrix/client/(api/v1|r0|v3)/events$
^/_matrix/client/(api/v1|r0|v3)/initialSync$ ^/_matrix/client/(api/v1|r0|v3)/initialSync$
^/_matrix/client/(api/v1|r0|v3)/rooms/[^/]+/initialSync$ ^/_matrix/client/(api/v1|r0|v3)/rooms/[^/]+/initialSync$
@ -200,13 +200,9 @@ information.
^/_matrix/federation/v1/query/ ^/_matrix/federation/v1/query/
^/_matrix/federation/v1/make_join/ ^/_matrix/federation/v1/make_join/
^/_matrix/federation/v1/make_leave/ ^/_matrix/federation/v1/make_leave/
^/_matrix/federation/v1/send_join/ ^/_matrix/federation/(v1|v2)/send_join/
^/_matrix/federation/v2/send_join/ ^/_matrix/federation/(v1|v2)/send_leave/
^/_matrix/federation/v1/send_leave/ ^/_matrix/federation/(v1|v2)/invite/
^/_matrix/federation/v2/send_leave/
^/_matrix/federation/v1/invite/
^/_matrix/federation/v2/invite/
^/_matrix/federation/v1/query_auth/
^/_matrix/federation/v1/event_auth/ ^/_matrix/federation/v1/event_auth/
^/_matrix/federation/v1/exchange_third_party_invite/ ^/_matrix/federation/v1/exchange_third_party_invite/
^/_matrix/federation/v1/user/devices/ ^/_matrix/federation/v1/user/devices/
@ -274,6 +270,8 @@ information.
Additionally, the following REST endpoints can be handled for GET requests: Additionally, the following REST endpoints can be handled for GET requests:
^/_matrix/federation/v1/groups/ ^/_matrix/federation/v1/groups/
^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
^/_matrix/client/(r0|v3|unstable)/groups/
Pagination requests can also be handled, but all requests for a given Pagination requests can also be handled, but all requests for a given
room must be routed to the same instance. Additionally, care must be taken to room must be routed to the same instance. Additionally, care must be taken to
@ -397,23 +395,23 @@ the stream writer for the `typing` stream:
The following endpoints should be routed directly to the worker configured as The following endpoints should be routed directly to the worker configured as
the stream writer for the `to_device` stream: the stream writer for the `to_device` stream:
^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/ ^/_matrix/client/(r0|v3|unstable)/sendToDevice/
##### The `account_data` stream ##### The `account_data` stream
The following endpoints should be routed directly to the worker configured as The following endpoints should be routed directly to the worker configured as
the stream writer for the `account_data` stream: the stream writer for the `account_data` stream:
^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags ^/_matrix/client/(r0|v3|unstable)/.*/tags
^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data ^/_matrix/client/(r0|v3|unstable)/.*/account_data
##### The `receipts` stream ##### The `receipts` stream
The following endpoints should be routed directly to the worker configured as The following endpoints should be routed directly to the worker configured as
the stream writer for the `receipts` stream: the stream writer for the `receipts` stream:
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt ^/_matrix/client/(r0|v3|unstable)/rooms/.*/receipt
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers ^/_matrix/client/(r0|v3|unstable)/rooms/.*/read_markers
##### The `presence` stream ##### The `presence` stream
@ -528,19 +526,28 @@ Note that if a reverse proxy is used , then `/_matrix/media/` must be routed for
Handles searches in the user directory. It can handle REST endpoints matching Handles searches in the user directory. It can handle REST endpoints matching
the following regular expressions: the following regular expressions:
^/_matrix/client/(api/v1|r0|v3|unstable)/user_directory/search$ ^/_matrix/client/(r0|v3|unstable)/user_directory/search$
When using this worker you must also set `update_user_directory: False` in the When using this worker you must also set `update_user_directory: false` in the
shared configuration file to stop the main synapse running background shared configuration file to stop the main synapse running background
jobs related to updating the user directory. jobs related to updating the user directory.
Above endpoint is not *required* to be routed to this worker. By default,
`update_user_directory` is set to `true`, which means the main process
will handle updates. All workers configured with `client` can handle the above
endpoint as long as either this worker or the main process are configured to
handle it, and are online.
If `update_user_directory` is set to `false`, and this worker is not running,
the above endpoint may give outdated results.
### `synapse.app.frontend_proxy` ### `synapse.app.frontend_proxy`
Proxies some frequently-requested client endpoints to add caching and remove Proxies some frequently-requested client endpoints to add caching and remove
load from the main synapse. It can handle REST endpoints matching the following load from the main synapse. It can handle REST endpoints matching the following
regular expressions: regular expressions:
^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload ^/_matrix/client/(r0|v3|unstable)/keys/upload
If `use_presence` is False in the homeserver config, it can also handle REST If `use_presence` is False in the homeserver config, it can also handle REST
endpoints matching the following regular expressions: endpoints matching the following regular expressions:

View file

@ -38,17 +38,11 @@ exclude = (?x)
|synapse/_scripts/update_synapse_database.py |synapse/_scripts/update_synapse_database.py
|synapse/storage/databases/__init__.py |synapse/storage/databases/__init__.py
|synapse/storage/databases/main/__init__.py
|synapse/storage/databases/main/cache.py |synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py |synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py |synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
|synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/roommember.py |synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py |synapse/storage/databases/main/state.py
|synapse/storage/schema/ |synapse/storage/schema/
@ -66,14 +60,6 @@ exclude = (?x)
|tests/federation/test_federation_server.py |tests/federation/test_federation_server.py
|tests/federation/transport/test_knocking.py |tests/federation/transport/test_knocking.py
|tests/federation/transport/test_server.py |tests/federation/transport/test_server.py
|tests/handlers/test_cas.py
|tests/handlers/test_directory.py
|tests/handlers/test_e2e_keys.py
|tests/handlers/test_federation.py
|tests/handlers/test_oidc.py
|tests/handlers/test_presence.py
|tests/handlers/test_profile.py
|tests/handlers/test_saml.py
|tests/handlers/test_typing.py |tests/handlers/test_typing.py
|tests/http/federation/test_matrix_federation_agent.py |tests/http/federation/test_matrix_federation_agent.py
|tests/http/federation/test_srv_resolver.py |tests/http/federation/test_srv_resolver.py
@ -85,7 +71,6 @@ exclude = (?x)
|tests/logging/test_terse_json.py |tests/logging/test_terse_json.py
|tests/module_api/test_api.py |tests/module_api/test_api.py
|tests/push/test_email.py |tests/push/test_email.py
|tests/push/test_http.py
|tests/push/test_presentable_names.py |tests/push/test_presentable_names.py
|tests/push/test_push_rule_evaluator.py |tests/push/test_push_rule_evaluator.py
|tests/rest/client/test_transactions.py |tests/rest/client/test_transactions.py
@ -94,12 +79,7 @@ exclude = (?x)
|tests/server.py |tests/server.py
|tests/server_notices/test_resource_limits_server_notices.py |tests/server_notices/test_resource_limits_server_notices.py
|tests/state/test_v2.py |tests/state/test_v2.py
|tests/storage/test_background_update.py
|tests/storage/test_base.py |tests/storage/test_base.py
|tests/storage/test_client_ips.py
|tests/storage/test_database.py
|tests/storage/test_event_federation.py
|tests/storage/test_id_generators.py
|tests/storage/test_roommember.py |tests/storage/test_roommember.py
|tests/test_metrics.py |tests/test_metrics.py
|tests/test_phone_home.py |tests/test_phone_home.py

View file

@ -66,11 +66,15 @@ def cli():
./scripts-dev/release.py tag ./scripts-dev/release.py tag
# ... wait for asssets to build ... # ... wait for assets to build ...
./scripts-dev/release.py publish ./scripts-dev/release.py publish
./scripts-dev/release.py upload ./scripts-dev/release.py upload
# Optional: generate some nice links for the announcement
./scripts-dev/release.py upload
If the env var GH_TOKEN (or GITHUB_TOKEN) is set, or passed into the If the env var GH_TOKEN (or GITHUB_TOKEN) is set, or passed into the
`tag`/`publish` command, then a new draft release will be created/published. `tag`/`publish` command, then a new draft release will be created/published.
""" """
@ -415,6 +419,41 @@ def upload():
) )
@cli.command()
def announce():
"""Generate markdown to announce the release."""
current_version, _, _ = parse_version_from_module()
tag_name = f"v{current_version}"
click.echo(
f"""
Hi everyone. Synapse {current_version} has just been released.
[notes](https://github.com/matrix-org/synapse/releases/tag/{tag_name}) |\
[docker](https://hub.docker.com/r/matrixdotorg/synapse/tags?name={tag_name}) | \
[debs](https://packages.matrix.org/debian/) | \
[pypi](https://pypi.org/project/matrix-synapse/{current_version}/)"""
)
if "rc" in tag_name:
click.echo(
"""
Announce the RC in
- #homeowners:matrix.org (Synapse Announcements)
- #synapse-dev:matrix.org"""
)
else:
click.echo(
"""
Announce the release in
- #homeowners:matrix.org (Synapse Announcements), bumping the version in the topic
- #synapse:matrix.org (Synapse Admins), bumping the version in the topic
- #synapse-dev:matrix.org
- #synapse-package-maintainers:matrix.org"""
)
def parse_version_from_module() -> Tuple[ def parse_version_from_module() -> Tuple[
version.Version, redbaron.RedBaron, redbaron.Node version.Version, redbaron.RedBaron, redbaron.Node
]: ]:

View file

@ -108,6 +108,7 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [
"types-jsonschema>=3.2.0", "types-jsonschema>=3.2.0",
"types-opentracing>=2.4.2", "types-opentracing>=2.4.2",
"types-Pillow>=8.3.4", "types-Pillow>=8.3.4",
"types-psycopg2>=2.9.9",
"types-pyOpenSSL>=20.0.7", "types-pyOpenSSL>=20.0.7",
"types-PyYAML>=5.4.10", "types-PyYAML>=5.4.10",
"types-requests>=2.26.0", "types-requests>=2.26.0",

View file

@ -68,7 +68,7 @@ try:
except ImportError: except ImportError:
pass pass
__version__ = "1.55.0" __version__ = "1.56.0rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when # We import here so that we don't have to install a bunch of deps when

View file

@ -261,7 +261,10 @@ class SynapseHomeServer(HomeServer):
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
if name == "metrics" and self.config.metrics.enable_metrics: if name == "metrics" and self.config.metrics.enable_metrics:
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) metrics_resource: Resource = MetricsResource(RegistryProxy)
if compress:
metrics_resource = gz_wrap(metrics_resource)
resources[METRICS_PREFIX] = metrics_resource
if name == "replication": if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationRestResource(self) resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
@ -348,6 +351,23 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
if config.server.gc_seconds: if config.server.gc_seconds:
synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
if (
config.registration.enable_registration
and not config.registration.enable_registration_without_verification
):
if (
not config.captcha.enable_registration_captcha
and not config.registration.registrations_require_3pid
and not config.registration.registration_requires_token
):
raise ConfigError(
"You have enabled open registration without any verification. This is a known vector for "
"spam and abuse. If you would like to allow public registration, please consider adding email, "
"captcha, or token-based verification. Otherwise this check can be removed by setting the "
"`enable_registration_without_verification` config option to `true`."
)
hs = SynapseHomeServer( hs = SynapseHomeServer(
config.server.server_name, config.server.server_name,
config=config, config=config,

View file

@ -37,6 +37,12 @@ DEFAULT_CONFIG = """\
# 'txn_limit' gives the maximum number of transactions to run per connection # 'txn_limit' gives the maximum number of transactions to run per connection
# before reconnecting. Defaults to 0, which means no limit. # before reconnecting. Defaults to 0, which means no limit.
# #
# 'allow_unsafe_locale' is an option specific to Postgres. Under the default behavior, Synapse will refuse to
# start if the postgres db is set to a non-C locale. You can override this behavior (which is *not* recommended)
# by setting 'allow_unsafe_locale' to true. Note that doing so may corrupt your database. You can find more information
# here: https://matrix-org.github.io/synapse/latest/postgres.html#fixing-incorrect-collate-or-ctype and here:
# https://wiki.postgresql.org/wiki/Locale_data_changes
#
# 'args' gives options which are passed through to the database engine, # 'args' gives options which are passed through to the database engine,
# except for options starting 'cp_', which are used to configure the Twisted # except for options starting 'cp_', which are used to configure the Twisted
# connection pool. For a reference to valid arguments, see: # connection pool. For a reference to valid arguments, see:

View file

@ -33,6 +33,10 @@ class RegistrationConfig(Config):
str(config["disable_registration"]) str(config["disable_registration"])
) )
self.enable_registration_without_verification = strtobool(
str(config.get("enable_registration_without_verification", False))
)
self.registrations_require_3pid = config.get("registrations_require_3pid", []) self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", []) self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True) self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
@ -207,10 +211,18 @@ class RegistrationConfig(Config):
# Registration can be rate-limited using the parameters in the "Ratelimiting" # Registration can be rate-limited using the parameters in the "Ratelimiting"
# section of this file. # section of this file.
# Enable registration for new users. # Enable registration for new users. Defaults to 'false'. It is highly recommended that if you enable registration,
# you use either captcha, email, or token-based verification to verify that new users are not bots. In order to enable registration
# without any verification, you must also set `enable_registration_without_verification`, found below.
# #
#enable_registration: false #enable_registration: false
# Enable registration without email or captcha verification. Note: this option is *not* recommended,
# as registration without verification is a known vector for spam and abuse. Defaults to false. Has no effect
# unless `enable_registration` is also enabled.
#
#enable_registration_without_verification: true
# Time that a user's session remains valid for, after they log in. # Time that a user's session remains valid for, after they log in.
# #
# Note that this is not currently compatible with guest logins. # Note that this is not currently compatible with guest logins.

View file

@ -676,6 +676,10 @@ class ServerConfig(Config):
): ):
raise ConfigError("'custom_template_directory' must be a string") raise ConfigError("'custom_template_directory' must be a string")
self.use_account_validity_in_account_status: bool = (
config.get("use_account_validity_in_account_status") or False
)
def has_tls_listener(self) -> bool: def has_tls_listener(self) -> bool:
return any(listener.tls for listener in self.listeners) return any(listener.tls for listener in self.listeners)

View file

@ -25,8 +25,8 @@ logger = logging.getLogger(__name__)
LEGACY_SPAM_CHECKER_WARNING = """ LEGACY_SPAM_CHECKER_WARNING = """
This server is using a spam checker module that is implementing the deprecated spam This server is using a spam checker module that is implementing the deprecated spam
checker interface. Please check with the module's maintainer to see if a new version checker interface. Please check with the module's maintainer to see if a new version
supporting Synapse's generic modules system is available. supporting Synapse's generic modules system is available. For more information, please
For more information, please see https://matrix-org.github.io/synapse/latest/modules.html see https://matrix-org.github.io/synapse/latest/modules/index.html
---------------------------------------------------------------------------------------""" ---------------------------------------------------------------------------------------"""

View file

@ -21,7 +21,6 @@ from typing import (
Awaitable, Awaitable,
Callable, Callable,
Collection, Collection,
Dict,
List, List,
Optional, Optional,
Tuple, Tuple,
@ -31,7 +30,7 @@ from typing import (
from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import ReadableFileWrapper from synapse.rest.media.v1.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias from synapse.types import RoomAlias, UserProfile
from synapse.util.async_helpers import maybe_awaitable from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING: if TYPE_CHECKING:
@ -50,7 +49,7 @@ USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bo
USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]] USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]] USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]]
USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]] USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]] CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]]
LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[ LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
[ [
Optional[dict], Optional[dict],
@ -383,7 +382,7 @@ class SpamChecker:
return True return True
async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server. """Checks if a user ID or display name are considered "spammy" by this server.
If the server considers a username spammy, then it will not be included in If the server considers a username spammy, then it will not be included in

View file

@ -38,8 +38,8 @@ from synapse.util.frozenutils import unfreeze
from . import EventBase from . import EventBase
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.handlers.relations import BundledAggregations
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main.relations import BundledAggregations
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'

View file

@ -22,7 +22,6 @@ from typing import (
Callable, Callable,
Collection, Collection,
Dict, Dict,
Iterable,
List, List,
Optional, Optional,
Tuple, Tuple,
@ -577,10 +576,10 @@ class FederationServer(FederationBase):
async def _on_context_state_request_compute( async def _on_context_state_request_compute(
self, room_id: str, event_id: Optional[str] self, room_id: str, event_id: Optional[str]
) -> Dict[str, list]: ) -> Dict[str, list]:
pdus: Collection[EventBase]
if event_id: if event_id:
pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu( event_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
room_id, event_id pdus = await self.store.get_events_as_list(event_ids)
)
else: else:
pdus = (await self.state.get_current_state(room_id)).values() pdus = (await self.state.get_current_state(room_id)).values()
@ -1093,7 +1092,7 @@ class FederationServer(FederationBase):
# has started processing). # has started processing).
while True: while True:
async with lock: async with lock:
logger.info("handling received PDU: %s", event) logger.info("handling received PDU in room %s: %s", room_id, event)
try: try:
with nested_logging_context(event.event_id): with nested_logging_context(event.event_id):
await self._federation_event_handler.on_receive_pdu( await self._federation_event_handler.on_receive_pdu(

View file

@ -26,6 +26,10 @@ class AccountHandler:
self._main_store = hs.get_datastores().main self._main_store = hs.get_datastores().main
self._is_mine = hs.is_mine self._is_mine = hs.is_mine
self._federation_client = hs.get_federation_client() self._federation_client = hs.get_federation_client()
self._use_account_validity_in_account_status = (
hs.config.server.use_account_validity_in_account_status
)
self._account_validity_handler = hs.get_account_validity_handler()
async def get_account_statuses( async def get_account_statuses(
self, self,
@ -106,6 +110,13 @@ class AccountHandler:
"deactivated": userinfo.is_deactivated, "deactivated": userinfo.is_deactivated,
} }
if self._use_account_validity_in_account_status:
status[
"org.matrix.expired"
] = await self._account_validity_handler.is_user_expired(
user_id.to_string()
)
return status return status
async def _get_remote_account_statuses( async def _get_remote_account_statuses(

View file

@ -950,54 +950,35 @@ class FederationHandler:
return event return event
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
if state_groups:
_, state = list(state_groups.items()).pop()
results = {(e.type, e.state_key): e for e in state}
if event.is_state():
# Get previous state
if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"]
if prev_id != event.event_id:
prev_event = await self.store.get_event(prev_id)
results[(event.type, event.state_key)] = prev_event
else:
del results[(event.type, event.state_key)]
res = list(results.values())
return res
else:
return []
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.""" """Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id) event = await self.store.get_event(event_id, check_room_id=room_id)
if event.internal_metadata.outlier:
raise NotFoundError("State not known at event %s" % (event_id,))
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
if state_groups: # get_state_groups_ids should return exactly one result
_, state = list(state_groups.items()).pop() assert len(state_groups) == 1
results = state
if event.is_state(): state_map = next(iter(state_groups.values()))
# Get previous state
state_key = event.get_state_key()
if state_key is not None:
# the event was not rejected (get_event raises a NotFoundError for rejected
# events) so the state at the event should include the event itself.
assert (
state_map.get((event.type, state_key)) == event.event_id
), "State at event did not include event itself"
# ... but we need the state *before* that event
if "replaces_state" in event.unsigned: if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"] prev_id = event.unsigned["replaces_state"]
if prev_id != event.event_id: state_map[(event.type, state_key)] = prev_id
results[(event.type, event.state_key)] = prev_id
else: else:
results.pop((event.type, event.state_key), None) del state_map[(event.type, state_key)]
return list(results.values()) return list(state_map.values())
else:
return []
async def on_backfill_request( async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int self, origin: str, room_id: str, pdu_list: List[str], limit: int

View file

@ -495,6 +495,7 @@ class EventCreationHandler:
allow_no_prev_events: bool = False, allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None, prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
require_consent: bool = True, require_consent: bool = True,
outlier: bool = False, outlier: bool = False,
historical: bool = False, historical: bool = False,
@ -529,6 +530,15 @@ class EventCreationHandler:
If non-None, prev_event_ids must also be provided. If non-None, prev_event_ids must also be provided.
state_event_ids:
The full state at a given event. This is used particularly by the MSC2716
/batch_send endpoint. One use case is with insertion events which float at
the beginning of a historical batch and don't have any `prev_events` to
derive from; we add all of these state events as the explicit state so the
rest of the historical batch can inherit the same state and state_group.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
require_consent: Whether to check if the requester has require_consent: Whether to check if the requester has
consented to the privacy policy. consented to the privacy policy.
@ -614,6 +624,7 @@ class EventCreationHandler:
allow_no_prev_events=allow_no_prev_events, allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids, prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids, auth_event_ids=auth_event_ids,
state_event_ids=state_event_ids,
depth=depth, depth=depth,
) )
@ -773,7 +784,7 @@ class EventCreationHandler:
event_dict: dict, event_dict: dict,
allow_no_prev_events: bool = False, allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None, prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None,
ratelimit: bool = True, ratelimit: bool = True,
txn_id: Optional[str] = None, txn_id: Optional[str] = None,
ignore_shadow_ban: bool = False, ignore_shadow_ban: bool = False,
@ -797,12 +808,14 @@ class EventCreationHandler:
The event IDs to use as the prev events. The event IDs to use as the prev events.
Should normally be left as None to automatically request them Should normally be left as None to automatically request them
from the database. from the database.
auth_event_ids: state_event_ids:
The event ids to use as the auth_events for the new event. The full state at a given event. This is used particularly by the MSC2716
Should normally be left as None, which will cause them to be calculated /batch_send endpoint. One use case is with insertion events which float at
based on the room state at the prev_events. the beginning of a historical batch and don't have any `prev_events` to
derive from; we add all of these state events as the explicit state so the
If non-None, prev_event_ids must also be provided. rest of the historical batch can inherit the same state and state_group.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
ratelimit: Whether to rate limit this send. ratelimit: Whether to rate limit this send.
txn_id: The transaction ID. txn_id: The transaction ID.
ignore_shadow_ban: True if shadow-banned users should be allowed to ignore_shadow_ban: True if shadow-banned users should be allowed to
@ -858,8 +871,9 @@ class EventCreationHandler:
requester, requester,
event_dict, event_dict,
txn_id=txn_id, txn_id=txn_id,
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids, prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids, state_event_ids=state_event_ids,
outlier=outlier, outlier=outlier,
historical=historical, historical=historical,
depth=depth, depth=depth,
@ -895,6 +909,7 @@ class EventCreationHandler:
allow_no_prev_events: bool = False, allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None, prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
depth: Optional[int] = None, depth: Optional[int] = None,
) -> Tuple[EventBase, EventContext]: ) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client """Create a new event for a local client
@ -917,6 +932,15 @@ class EventCreationHandler:
Should normally be left as None, which will cause them to be calculated Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events. based on the room state at the prev_events.
state_event_ids:
The full state at a given event. This is used particularly by the MSC2716
/batch_send endpoint. One use case is with insertion events which float at
the beginning of a historical batch and don't have any `prev_events` to
derive from; we add all of these state events as the explicit state so the
rest of the historical batch can inherit the same state and state_group.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
depth: Override the depth used to order the event in the DAG. depth: Override the depth used to order the event in the DAG.
Should normally be set to None, which will cause the depth to be calculated Should normally be set to None, which will cause the depth to be calculated
based on the prev_events. based on the prev_events.
@ -924,31 +948,26 @@ class EventCreationHandler:
Returns: Returns:
Tuple of created event, context Tuple of created event, context
""" """
# Strip down the auth_event_ids to only what we need to auth the event. # Strip down the state_event_ids to only what we need to auth the event.
# For example, we don't need extra m.room.member that don't match event.sender # For example, we don't need extra m.room.member that don't match event.sender
full_state_ids_at_event = None if state_event_ids is not None:
if auth_event_ids is not None: # Do a quick check to make sure that prev_event_ids is present to
# If auth events are provided, prev events must be also. # make the type-checking around `builder.build` happy.
# prev_event_ids could be an empty array though. # prev_event_ids could be an empty array though.
assert prev_event_ids is not None assert prev_event_ids is not None
# Copy the full auth state before it stripped down
full_state_ids_at_event = auth_event_ids.copy()
temp_event = await builder.build( temp_event = await builder.build(
prev_event_ids=prev_event_ids, prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids, auth_event_ids=state_event_ids,
depth=depth, depth=depth,
) )
auth_events = await self.store.get_events_as_list(auth_event_ids) state_events = await self.store.get_events_as_list(state_event_ids)
# Create a StateMap[str] # Create a StateMap[str]
auth_event_state_map = { state_map = {(e.type, e.state_key): e.event_id for e in state_events}
(e.type, e.state_key): e.event_id for e in auth_events # Actually strip down and only use the necessary auth events
}
# Actually strip down and use the necessary auth events
auth_event_ids = self._event_auth_handler.compute_auth_events( auth_event_ids = self._event_auth_handler.compute_auth_events(
event=temp_event, event=temp_event,
current_state_ids=auth_event_state_map, current_state_ids=state_map,
for_verification=False, for_verification=False,
) )
@ -991,12 +1010,16 @@ class EventCreationHandler:
context = EventContext.for_outlier() context = EventContext.for_outlier()
elif ( elif (
event.type == EventTypes.MSC2716_INSERTION event.type == EventTypes.MSC2716_INSERTION
and full_state_ids_at_event and state_event_ids
and builder.internal_metadata.is_historical() and builder.internal_metadata.is_historical()
): ):
# Add explicit state to the insertion event so it has state to derive
# from even though it's floating with no `prev_events`. The rest of
# the batch can derive from this state and state_group.
#
# TODO(faster_joins): figure out how this works, and make sure that the # TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete. # old state is complete.
old_state = await self.store.get_events_as_list(full_state_ids_at_event) old_state = await self.store.get_events_as_list(state_event_ids)
context = await self.state.compute_event_context(event, old_state=old_state) context = await self.state.compute_event_context(event, old_state=old_state)
else: else:
context = await self.state.compute_event_context(event) context = await self.state.compute_event_context(event)

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
import attr import attr
@ -134,6 +134,7 @@ class PaginationHandler:
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._server_name = hs.hostname self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler() self._room_shutdown_handler = hs.get_room_shutdown_handler()
self._relations_handler = hs.get_relations_handler()
self.pagination_lock = ReadWriteLock() self.pagination_lock = ReadWriteLock()
# IDs of rooms in which there currently an active purge *or delete* operation. # IDs of rooms in which there currently an active purge *or delete* operation.
@ -422,7 +423,7 @@ class PaginationHandler:
pagin_config: PaginationConfig, pagin_config: PaginationConfig,
as_client_event: bool = True, as_client_event: bool = True,
event_filter: Optional[Filter] = None, event_filter: Optional[Filter] = None,
) -> Dict[str, Any]: ) -> JsonDict:
"""Get messages in a room. """Get messages in a room.
Args: Args:
@ -431,6 +432,7 @@ class PaginationHandler:
pagin_config: The pagination config rules to apply, if any. pagin_config: The pagination config rules to apply, if any.
as_client_event: True to get events in client-server format. as_client_event: True to get events in client-server format.
event_filter: Filter to apply to results or None event_filter: Filter to apply to results or None
Returns: Returns:
Pagination API results Pagination API results
""" """
@ -538,7 +540,9 @@ class PaginationHandler:
state_dict = await self.store.get_events(list(state_ids.values())) state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values() state = state_dict.values()
aggregations = await self.store.get_bundled_aggregations(events, user_id) aggregations = await self._relations_handler.get_bundled_aggregations(
events, user_id
)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View file

@ -336,12 +336,18 @@ class ProfileHandler:
"""Check that the size and content type of the avatar at the given MXC URI are """Check that the size and content type of the avatar at the given MXC URI are
within the configured limits. within the configured limits.
If the given `mxc` is empty, no checks are performed. (Users are always able to
unset their avatar.)
Args: Args:
mxc: The MXC URI at which the avatar can be found. mxc: The MXC URI at which the avatar can be found.
Returns: Returns:
A boolean indicating whether the file can be allowed to be set as an avatar. A boolean indicating whether the file can be allowed to be set as an avatar.
""" """
if mxc == "":
return True
if not self.max_avatar_size and not self.allowed_avatar_mimetypes: if not self.max_avatar_size and not self.allowed_avatar_mimetypes:
return True return True

View file

@ -0,0 +1,271 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
import attr
from frozendict import frozendict
from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.types import JsonDict, Requester, StreamToken
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
latest_event: EventBase
# The latest edit to the latest event in the thread.
latest_edit: Optional[EventBase]
# The total number of events in the thread.
count: int
# True if the current user has sent an event to the thread.
current_user_participated: bool
@attr.s(slots=True, auto_attribs=True)
class BundledAggregations:
"""
The bundled aggregations for an event.
Some values require additional processing during serialization.
"""
annotations: Optional[JsonDict] = None
references: Optional[JsonDict] = None
replace: Optional[EventBase] = None
thread: Optional[_ThreadAggregation] = None
def __bool__(self) -> bool:
return bool(self.annotations or self.references or self.replace or self.thread)
class RelationsHandler:
def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main
self._storage = hs.get_storage()
self._auth = hs.get_auth()
self._clock = hs.get_clock()
self._event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
async def get_relations(
self,
requester: Requester,
event_id: str,
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.
TODO Accept a PaginationConfig instead of individual pagination parameters.
Args:
requester: The user requesting the relations.
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
aggregation_key: Only fetch events with this aggregation key, if given.
limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
The pagination chunk.
"""
user_id = requester.user.to_string()
# TODO Properly handle a user leaving a room.
(_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)
# This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it.
event = await self._event_handler.get_event(requester.user, room_id, event_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
pagination_chunk = await self._main_store.get_relations_for_event(
event_id=event_id,
event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
aggregation_key=aggregation_key,
limit=limit,
direction=direction,
from_token=from_token,
to_token=to_token,
)
events = await self._main_store.get_events_as_list(
[c["event_id"] for c in pagination_chunk.chunk]
)
events = await filter_events_for_client(
self._storage, user_id, events, is_peeking=(member_event_id is None)
)
now = self._clock.time_msec()
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
original_event = self._event_serializer.serialize_event(
event, now, bundle_aggregations=None
)
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)
return_value = await pagination_chunk.to_dict(self._main_store)
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event
return return_value
async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
Note that this does not use a cache, but depends on cached methods.
Args:
event: The event to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
The bundled aggregations for an event, if bundled aggregations are
enabled and the event can have bundled aggregations.
"""
# Do not bundle aggregations for an event which represents an edit or an
# annotation. It does not make sense for them to have related events.
relates_to = event.content.get("m.relates_to")
if isinstance(relates_to, (dict, frozendict)):
relation_type = relates_to.get("rel_type")
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
return None
event_id = event.event_id
room_id = event.room_id
# The bundled aggregations to include, a mapping of relation type to a
# type-specific value. Some types include the direct return type here
# while others need more processing during serialization.
aggregations = BundledAggregations()
annotations = await self._main_store.get_aggregation_groups_for_event(
event_id, room_id
)
if annotations.chunk:
aggregations.annotations = await annotations.to_dict(
cast("DataStore", self)
)
references = await self._main_store.get_relations_for_event(
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self))
# Store the bundled aggregations in the event metadata for later use.
return aggregations
async def get_bundled_aggregations(
self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.
Args:
events: The iterable of events to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results.
"""
# De-duplicate events by ID to handle the same event requested multiple times.
#
# State events do not get bundled aggregations.
events_by_id = {
event.event_id: event for event in events if not event.is_state()
}
# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}
# Fetch other relations per event.
for event in events_by_id.values():
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result:
results[event.event_id] = event_result
# Fetch any edits (but not for redacted events).
edits = await self._main_store.get_applicable_edits(
[
event_id
for event_id, event in events_by_id.items()
if not event.internal_metadata.is_redacted()
]
)
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit
# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
# Only fetch participated for a limited selection based on what had
# summaries.
participated = await self._main_store.get_threads_participated(
[event_id for event_id, summary in summaries.items() if summary], user_id
)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event, edit = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)
return results

View file

@ -60,8 +60,8 @@ from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents from synapse.events.utils import copy_power_levels_contents
from synapse.federation.federation_client import InvalidResponseError from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.federation import get_domains_from_state from synapse.handlers.federation import get_domains_from_state
from synapse.handlers.relations import BundledAggregations
from synapse.rest.admin._base import assert_user_is_admin from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.streams import EventSource from synapse.streams import EventSource
from synapse.types import ( from synapse.types import (
@ -1128,6 +1128,7 @@ class RoomContextHandler:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_store = self.storage.state self.state_store = self.storage.state
self._relations_handler = hs.get_relations_handler()
async def get_event_context( async def get_event_context(
self, self,
@ -1200,7 +1201,7 @@ class RoomContextHandler:
event = filtered[0] event = filtered[0]
# Fetch the aggregations. # Fetch the aggregations.
aggregations = await self.store.get_bundled_aggregations( aggregations = await self._relations_handler.get_bundled_aggregations(
itertools.chain(events_before, (event,), events_after), itertools.chain(events_before, (event,), events_after),
user.to_string(), user.to_string(),
) )

View file

@ -123,12 +123,11 @@ class RoomBatchHandler:
return create_requester(user_id, app_service=app_service) return create_requester(user_id, app_service=app_service)
async def get_most_recent_auth_event_ids_from_event_id_list( async def get_most_recent_full_state_ids_from_event_id_list(
self, event_ids: List[str] self, event_ids: List[str]
) -> List[str]: ) -> List[str]:
"""Find the most recent auth event ids (derived from state events) that """Find the most recent event_id and grab the full state at that event.
allowed that message to be sent. We will use this as a base We will use this as a base to auth our historical messages against.
to auth our historical messages against.
Args: Args:
event_ids: List of event ID's to look at event_ids: List of event ID's to look at
@ -138,38 +137,37 @@ class RoomBatchHandler:
""" """
( (
most_recent_prev_event_id, most_recent_event_id,
_, _,
) = await self.store.get_max_depth_of(event_ids) ) = await self.store.get_max_depth_of(event_ids)
# mapping from (type, state_key) -> state_event_id # mapping from (type, state_key) -> state_event_id
prev_state_map = await self.state_store.get_state_ids_for_event( prev_state_map = await self.state_store.get_state_ids_for_event(
most_recent_prev_event_id most_recent_event_id
) )
# List of state event ID's # List of state event ID's
prev_state_ids = list(prev_state_map.values()) full_state_ids = list(prev_state_map.values())
auth_event_ids = prev_state_ids
return auth_event_ids return full_state_ids
async def persist_state_events_at_start( async def persist_state_events_at_start(
self, self,
state_events_at_start: List[JsonDict], state_events_at_start: List[JsonDict],
room_id: str, room_id: str,
initial_auth_event_ids: List[str], initial_state_event_ids: List[str],
app_service_requester: Requester, app_service_requester: Requester,
) -> List[str]: ) -> List[str]:
"""Takes all `state_events_at_start` event dictionaries and creates/persists """Takes all `state_events_at_start` event dictionaries and creates/persists
them as floating state events which don't resolve into the current room state. them in a floating state event chain which don't resolve into the current room
They are floating because they reference a fake prev_event which doesn't connect state. They are floating because they reference no prev_events and are marked
to the normal DAG at all. as outliers which disconnects them from the normal DAG.
Args: Args:
state_events_at_start: state_events_at_start:
room_id: Room where you want the events persisted in. room_id: Room where you want the events persisted in.
initial_auth_event_ids: These will be the auth_events for the first initial_state_event_ids:
state event created. Each event created afterwards will be The base set of state for the historical batch which the floating
added to the list of auth events for the next state event state chain will derive from. This should probably be the state
created. from the `prev_event` defined by `/batch_send?prev_event_id=$abc`.
app_service_requester: The requester of an application service. app_service_requester: The requester of an application service.
Returns: Returns:
@ -178,7 +176,7 @@ class RoomBatchHandler:
assert app_service_requester.app_service assert app_service_requester.app_service
state_event_ids_at_start = [] state_event_ids_at_start = []
auth_event_ids = initial_auth_event_ids.copy() state_event_ids = initial_state_event_ids.copy()
# Make the state events float off on their own by specifying no # Make the state events float off on their own by specifying no
# prev_events for the first one in the chain so we don't have a bunch of # prev_events for the first one in the chain so we don't have a bunch of
@ -191,9 +189,7 @@ class RoomBatchHandler:
) )
logger.debug( logger.debug(
"RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s", "RoomBatchSendEventRestServlet inserting state_event=%s", state_event
state_event,
auth_event_ids,
) )
event_dict = { event_dict = {
@ -219,16 +215,26 @@ class RoomBatchHandler:
room_id=room_id, room_id=room_id,
action=membership, action=membership,
content=event_dict["content"], content=event_dict["content"],
# Mark as an outlier to disconnect it from the normal DAG
# and not show up between batches of history.
outlier=True, outlier=True,
historical=True, historical=True,
# Only the first event in the chain should be floating. # Only the first event in the state chain should be floating.
# The rest should hang off each other in a chain. # The rest should hang off each other in a chain.
allow_no_prev_events=index == 0, allow_no_prev_events=index == 0,
prev_event_ids=prev_event_ids_for_state_chain, prev_event_ids=prev_event_ids_for_state_chain,
# Since each state event is marked as an outlier, the
# `EventContext.for_outlier()` won't have any `state_ids`
# set and therefore can't derive any state even though the
# prev_events are set. Also since the first event in the
# state chain is floating with no `prev_events`, it can't
# derive state from anywhere automatically. So we need to
# set some state explicitly.
#
# Make sure to use a copy of this list because we modify it # Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same # later in the loop here. Otherwise it will be the same
# reference and also update in the event when we append later. # reference and also update in the event when we append later.
auth_event_ids=auth_event_ids.copy(), state_event_ids=state_event_ids.copy(),
) )
else: else:
# TODO: Add some complement tests that adds state that is not member joins # TODO: Add some complement tests that adds state that is not member joins
@ -242,21 +248,31 @@ class RoomBatchHandler:
state_event["sender"], app_service_requester.app_service state_event["sender"], app_service_requester.app_service
), ),
event_dict, event_dict,
# Mark as an outlier to disconnect it from the normal DAG
# and not show up between batches of history.
outlier=True, outlier=True,
historical=True, historical=True,
# Only the first event in the chain should be floating. # Only the first event in the state chain should be floating.
# The rest should hang off each other in a chain. # The rest should hang off each other in a chain.
allow_no_prev_events=index == 0, allow_no_prev_events=index == 0,
prev_event_ids=prev_event_ids_for_state_chain, prev_event_ids=prev_event_ids_for_state_chain,
# Since each state event is marked as an outlier, the
# `EventContext.for_outlier()` won't have any `state_ids`
# set and therefore can't derive any state even though the
# prev_events are set. Also since the first event in the
# state chain is floating with no `prev_events`, it can't
# derive state from anywhere automatically. So we need to
# set some state explicitly.
#
# Make sure to use a copy of this list because we modify it # Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same # later in the loop here. Otherwise it will be the same
# reference and also update in the event when we append later. # reference and also update in the event when we append later.
auth_event_ids=auth_event_ids.copy(), state_event_ids=state_event_ids.copy(),
) )
event_id = event.event_id event_id = event.event_id
state_event_ids_at_start.append(event_id) state_event_ids_at_start.append(event_id)
auth_event_ids.append(event_id) state_event_ids.append(event_id)
# Connect all the state in a floating chain # Connect all the state in a floating chain
prev_event_ids_for_state_chain = [event_id] prev_event_ids_for_state_chain = [event_id]
@ -267,7 +283,7 @@ class RoomBatchHandler:
events_to_create: List[JsonDict], events_to_create: List[JsonDict],
room_id: str, room_id: str,
inherited_depth: int, inherited_depth: int,
auth_event_ids: List[str], initial_state_event_ids: List[str],
app_service_requester: Requester, app_service_requester: Requester,
) -> List[str]: ) -> List[str]:
"""Create and persists all events provided sequentially. Handles the """Create and persists all events provided sequentially. Handles the
@ -283,8 +299,10 @@ class RoomBatchHandler:
room_id: Room where you want the events persisted in. room_id: Room where you want the events persisted in.
inherited_depth: The depth to create the events at (you will inherited_depth: The depth to create the events at (you will
probably by calling inherit_depth_from_prev_ids(...)). probably by calling inherit_depth_from_prev_ids(...)).
auth_event_ids: Define which events allow you to create the given initial_state_event_ids:
event in the room. This is used to set explicit state for the insertion event at
the start of the historical batch since it's floating with no
prev_events to derive state from automatically.
app_service_requester: The requester of an application service. app_service_requester: The requester of an application service.
Returns: Returns:
@ -292,6 +310,11 @@ class RoomBatchHandler:
""" """
assert app_service_requester.app_service assert app_service_requester.app_service
# We expect the first event in a historical batch to be an insertion event
assert events_to_create[0]["type"] == EventTypes.MSC2716_INSERTION
# We expect the last event in a historical batch to be an batch event
assert events_to_create[-1]["type"] == EventTypes.MSC2716_BATCH
# Make the historical event chain float off on its own by specifying no # Make the historical event chain float off on its own by specifying no
# prev_events for the first event in the chain which causes the HS to # prev_events for the first event in the chain which causes the HS to
# ask for the state at the start of the batch later. # ask for the state at the start of the batch later.
@ -323,11 +346,16 @@ class RoomBatchHandler:
ev["sender"], app_service_requester.app_service ev["sender"], app_service_requester.app_service
), ),
event_dict, event_dict,
# Only the first event in the chain should be floating. # Only the first event (which is the insertion event) in the
# The rest should hang off each other in a chain. # chain should be floating. The rest should hang off each other
# in a chain.
allow_no_prev_events=index == 0, allow_no_prev_events=index == 0,
prev_event_ids=event_dict.get("prev_events"), prev_event_ids=event_dict.get("prev_events"),
auth_event_ids=auth_event_ids, # Since the first event (which is the insertion event) in the
# chain is floating with no `prev_events`, it can't derive state
# from anywhere automatically. So we need to set some state
# explicitly.
state_event_ids=initial_state_event_ids if index == 0 else None,
historical=True, historical=True,
depth=inherited_depth, depth=inherited_depth,
) )
@ -345,10 +373,9 @@ class RoomBatchHandler:
) )
logger.debug( logger.debug(
"RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s",
event, event,
prev_event_ids, prev_event_ids,
auth_event_ids,
) )
events_to_persist.append((event, context)) events_to_persist.append((event, context))
@ -378,12 +405,12 @@ class RoomBatchHandler:
room_id: str, room_id: str,
batch_id_to_connect_to: str, batch_id_to_connect_to: str,
inherited_depth: int, inherited_depth: int,
auth_event_ids: List[str], initial_state_event_ids: List[str],
app_service_requester: Requester, app_service_requester: Requester,
) -> Tuple[List[str], str]: ) -> Tuple[List[str], str]:
""" """
Handles creating and persisting all of the historical events as well Handles creating and persisting all of the historical events as well as
as insertion and batch meta events to make the batch navigable in the DAG. insertion and batch meta events to make the batch navigable in the DAG.
Args: Args:
events_to_create: List of historical events to create in JSON events_to_create: List of historical events to create in JSON
@ -393,8 +420,13 @@ class RoomBatchHandler:
want this batch to connect to. want this batch to connect to.
inherited_depth: The depth to create the events at (you will inherited_depth: The depth to create the events at (you will
probably by calling inherit_depth_from_prev_ids(...)). probably by calling inherit_depth_from_prev_ids(...)).
auth_event_ids: Define which events allow you to create the given initial_state_event_ids:
event in the room. This is used to set explicit state for the insertion event at
the start of the historical batch since it's floating with no
prev_events to derive state from automatically. This should
probably be the state from the `prev_event` defined by
`/batch_send?prev_event_id=$abc` plus the outcome of
`persist_state_events_at_start`
app_service_requester: The requester of an application service. app_service_requester: The requester of an application service.
Returns: Returns:
@ -440,7 +472,7 @@ class RoomBatchHandler:
events_to_create=events_to_create, events_to_create=events_to_create,
room_id=room_id, room_id=room_id,
inherited_depth=inherited_depth, inherited_depth=inherited_depth,
auth_event_ids=auth_event_ids, initial_state_event_ids=initial_state_event_ids,
app_service_requester=app_service_requester, app_service_requester=app_service_requester,
) )

View file

@ -271,7 +271,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
membership: str, membership: str,
allow_no_prev_events: bool = False, allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None, prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None,
txn_id: Optional[str] = None, txn_id: Optional[str] = None,
ratelimit: bool = True, ratelimit: bool = True,
content: Optional[dict] = None, content: Optional[dict] = None,
@ -294,10 +294,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
events should have a prev_event and we should only use this in special events should have a prev_event and we should only use this in special
cases like MSC2716. cases like MSC2716.
prev_event_ids: The event IDs to use as the prev events prev_event_ids: The event IDs to use as the prev events
auth_event_ids: state_event_ids:
The event ids to use as the auth_events for the new event. The full state at a given event. This is used particularly by the MSC2716
Should normally be left as None, which will cause them to be calculated /batch_send endpoint. One use case is the historical `state_events_at_start`;
based on the room state at the prev_events. since each is marked as an `outlier`, the `EventContext.for_outlier()` won't
have any `state_ids` set and therefore can't derive any state even though the
prev_events are set so we need to set them ourself via this argument.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
txn_id: txn_id:
ratelimit: ratelimit:
@ -352,7 +356,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id=txn_id, txn_id=txn_id,
allow_no_prev_events=allow_no_prev_events, allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids, prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids, state_event_ids=state_event_ids,
require_consent=require_consent, require_consent=require_consent,
outlier=outlier, outlier=outlier,
historical=historical, historical=historical,
@ -455,7 +459,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
historical: bool = False, historical: bool = False,
allow_no_prev_events: bool = False, allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None, prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None,
) -> Tuple[str, int]: ) -> Tuple[str, int]:
"""Update a user's membership in a room. """Update a user's membership in a room.
@ -483,10 +487,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
events should have a prev_event and we should only use this in special events should have a prev_event and we should only use this in special
cases like MSC2716. cases like MSC2716.
prev_event_ids: The event IDs to use as the prev events prev_event_ids: The event IDs to use as the prev events
auth_event_ids: state_event_ids:
The event ids to use as the auth_events for the new event. The full state at a given event. This is used particularly by the MSC2716
Should normally be left as None, which will cause them to be calculated /batch_send endpoint. One use case is the historical `state_events_at_start`;
based on the room state at the prev_events. since each is marked as an `outlier`, the `EventContext.for_outlier()` won't
have any `state_ids` set and therefore can't derive any state even though the
prev_events are set so we need to set them ourself via this argument.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
Returns: Returns:
A tuple of the new event ID and stream ID. A tuple of the new event ID and stream ID.
@ -525,7 +533,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
historical=historical, historical=historical,
allow_no_prev_events=allow_no_prev_events, allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids, prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids, state_event_ids=state_event_ids,
) )
return result return result
@ -547,7 +555,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
historical: bool = False, historical: bool = False,
allow_no_prev_events: bool = False, allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None, prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None,
) -> Tuple[str, int]: ) -> Tuple[str, int]:
"""Helper for update_membership. """Helper for update_membership.
@ -577,10 +585,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
events should have a prev_event and we should only use this in special events should have a prev_event and we should only use this in special
cases like MSC2716. cases like MSC2716.
prev_event_ids: The event IDs to use as the prev events prev_event_ids: The event IDs to use as the prev events
auth_event_ids: state_event_ids:
The event ids to use as the auth_events for the new event. The full state at a given event. This is used particularly by the MSC2716
Should normally be left as None, which will cause them to be calculated /batch_send endpoint. One use case is the historical `state_events_at_start`;
based on the room state at the prev_events. since each is marked as an `outlier`, the `EventContext.for_outlier()` won't
have any `state_ids` set and therefore can't derive any state even though the
prev_events are set so we need to set them ourself via this argument.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
Returns: Returns:
A tuple of the new event ID and stream ID. A tuple of the new event ID and stream ID.
@ -707,7 +719,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ratelimit=ratelimit, ratelimit=ratelimit,
allow_no_prev_events=allow_no_prev_events, allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids, prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids, state_event_ids=state_event_ids,
content=content, content=content,
require_consent=require_consent, require_consent=require_consent,
outlier=outlier, outlier=outlier,
@ -931,7 +943,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id=txn_id, txn_id=txn_id,
ratelimit=ratelimit, ratelimit=ratelimit,
prev_event_ids=latest_event_ids, prev_event_ids=latest_event_ids,
auth_event_ids=auth_event_ids, state_event_ids=state_event_ids,
content=content, content=content,
require_consent=require_consent, require_consent=require_consent,
outlier=outlier, outlier=outlier,

View file

@ -54,6 +54,7 @@ class SearchHandler:
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.hs = hs self.hs = hs
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_store = self.storage.state self.state_store = self.storage.state
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -354,7 +355,7 @@ class SearchHandler:
aggregations = None aggregations = None
if self._msc3666_enabled: if self._msc3666_enabled:
aggregations = await self.store.get_bundled_aggregations( aggregations = await self._relations_handler.get_bundled_aggregations(
# Generate an iterable of EventBase for all the events that will be # Generate an iterable of EventBase for all the events that will be
# returned, including contextual events. # returned, including contextual events.
itertools.chain( itertools.chain(

View file

@ -28,16 +28,16 @@ from typing import (
import attr import attr
from prometheus_client import Counter from prometheus_client import Counter
from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.api.filtering import FilterCollection from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase from synapse.events import EventBase
from synapse.handlers.relations import BundledAggregations
from synapse.logging.context import current_context from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import ( from synapse.types import (
@ -269,6 +269,7 @@ class SyncHandler:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self._relations_handler = hs.get_relations_handler()
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
@ -638,9 +639,11 @@ class SyncHandler:
# as clients will have all the necessary information. # as clients will have all the necessary information.
bundled_aggregations = None bundled_aggregations = None
if limited or newly_joined_room: if limited or newly_joined_room:
bundled_aggregations = await self.store.get_bundled_aggregations( bundled_aggregations = (
await self._relations_handler.get_bundled_aggregations(
recents, sync_config.user.to_string() recents, sync_config.user.to_string()
) )
)
return TimelineBatch( return TimelineBatch(
events=recents, events=recents,
@ -1600,7 +1603,7 @@ class SyncHandler:
return set(), set(), set(), set() return set(), set(), set(), set()
# 3. Work out which rooms need reporting in the sync response. # 3. Work out which rooms need reporting in the sync response.
ignored_users = await self._get_ignored_users(user_id) ignored_users = await self.store.ignored_users(user_id)
if since_token: if since_token:
room_changes = await self._get_rooms_changed( room_changes = await self._get_rooms_changed(
sync_result_builder, ignored_users sync_result_builder, ignored_users
@ -1626,7 +1629,6 @@ class SyncHandler:
logger.debug("Generating room entry for %s", room_entry.room_id) logger.debug("Generating room entry for %s", room_entry.room_id)
await self._generate_room_entry( await self._generate_room_entry(
sync_result_builder, sync_result_builder,
ignored_users,
room_entry, room_entry,
ephemeral=ephemeral_by_room.get(room_entry.room_id, []), ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
tags=tags_by_room.get(room_entry.room_id), tags=tags_by_room.get(room_entry.room_id),
@ -1656,29 +1658,6 @@ class SyncHandler:
newly_left_users, newly_left_users,
) )
async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
"""Retrieve the users ignored by the given user from their global account_data.
Returns an empty set if
- there is no global account_data entry for ignored_users
- there is such an entry, but it's not a JSON object.
"""
# TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
ignored_account_data = (
await self.store.get_global_account_data_by_type_for_user(
user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST
)
)
# If there is ignored users account data and it matches the proper type,
# then use it.
ignored_users: FrozenSet[str] = frozenset()
if ignored_account_data:
ignored_users_data = ignored_account_data.get("ignored_users", {})
if isinstance(ignored_users_data, dict):
ignored_users = frozenset(ignored_users_data.keys())
return ignored_users
async def _have_rooms_changed( async def _have_rooms_changed(
self, sync_result_builder: "SyncResultBuilder" self, sync_result_builder: "SyncResultBuilder"
) -> bool: ) -> bool:
@ -2021,7 +2000,6 @@ class SyncHandler:
async def _generate_room_entry( async def _generate_room_entry(
self, self,
sync_result_builder: "SyncResultBuilder", sync_result_builder: "SyncResultBuilder",
ignored_users: FrozenSet[str],
room_builder: "RoomSyncResultBuilder", room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict], ephemeral: List[JsonDict],
tags: Optional[Dict[str, Dict[str, Any]]], tags: Optional[Dict[str, Dict[str, Any]]],
@ -2050,7 +2028,6 @@ class SyncHandler:
Args: Args:
sync_result_builder sync_result_builder
ignored_users: Set of users ignored by user.
room_builder room_builder
ephemeral: List of new ephemeral events for room ephemeral: List of new ephemeral events for room
tags: List of *all* tags for room, or None if there has been tags: List of *all* tags for room, or None if there has been

View file

@ -19,8 +19,8 @@ import synapse.metrics
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main.user_directory import SearchResult
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
if TYPE_CHECKING: if TYPE_CHECKING:
@ -78,7 +78,7 @@ class UserDirectoryHandler(StateDeltasHandler):
async def search_users( async def search_users(
self, user_id: str, search_term: str, limit: int self, user_id: str, search_term: str, limit: int
) -> JsonDict: ) -> SearchResult:
"""Searches for users in directory """Searches for users in directory
Returns: Returns:

View file

@ -111,6 +111,7 @@ from synapse.types import (
StateMap, StateMap,
UserID, UserID,
UserInfo, UserInfo,
UserProfile,
create_requester, create_requester,
) )
from synapse.util import Clock from synapse.util import Clock
@ -150,6 +151,7 @@ __all__ = [
"EventBase", "EventBase",
"StateMap", "StateMap",
"ProfileInfo", "ProfileInfo",
"UserProfile",
] ]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -609,15 +611,18 @@ class ModuleApi:
localpart: str, localpart: str,
displayname: Optional[str] = None, displayname: Optional[str] = None,
emails: Optional[List[str]] = None, emails: Optional[List[str]] = None,
admin: bool = False,
) -> "defer.Deferred[str]": ) -> "defer.Deferred[str]":
"""Registers a new user with given localpart and optional displayname, emails. """Registers a new user with given localpart and optional displayname, emails.
Added in Synapse v1.2.0. Added in Synapse v1.2.0.
Changed in Synapse v1.56.0: add 'admin' argument to register the user as admin.
Args: Args:
localpart: The localpart of the new user. localpart: The localpart of the new user.
displayname: The displayname of the new user. displayname: The displayname of the new user.
emails: Emails to bind to the new user. emails: Emails to bind to the new user.
admin: True if the user should be registered as a server admin.
Raises: Raises:
SynapseError if there is an error performing the registration. Check the SynapseError if there is an error performing the registration. Check the
@ -631,6 +636,7 @@ class ModuleApi:
localpart=localpart, localpart=localpart,
default_display_name=displayname, default_display_name=displayname,
bind_emails=emails or [], bind_emails=emails or [],
admin=admin,
) )
) )
@ -665,7 +671,8 @@ class ModuleApi:
def record_user_external_id( def record_user_external_id(
self, auth_provider_id: str, remote_user_id: str, registered_user_id: str self, auth_provider_id: str, remote_user_id: str, registered_user_id: str
) -> defer.Deferred: ) -> defer.Deferred:
"""Record a mapping from an external user id to a mxid """Record a mapping between an external user id from a single sign-on provider
and a mxid.
Added in Synapse v1.9.0. Added in Synapse v1.9.0.
@ -1280,6 +1287,30 @@ class ModuleApi:
""" """
await self._registration_handler.check_username(username) await self._registration_handler.check_username(username)
async def store_remote_3pid_association(
self, user_id: str, medium: str, address: str, id_server: str
) -> None:
"""Stores an existing association between a user ID and a third-party identifier.
The association must already exist on the remote identity server.
Added in Synapse v1.56.0.
Args:
user_id: The user ID that's been associated with the 3PID.
medium: The medium of the 3PID (current supported values are "msisdn" and
"email").
address: The address of the 3PID.
id_server: The identity server the 3PID association has been registered on.
This should only be the domain (or IP address, optionally with the port
number) for the identity server. This will be used to reach out to the
identity server using HTTPS (unless specified otherwise by Synapse's
configuration) when attempting to unbind the third-party identifier.
"""
await self._store.add_user_bound_threepid(user_id, medium, address, id_server)
class PublicRoomListManager: class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room """Contains methods for adding to, removing from and querying whether a room

View file

@ -24,6 +24,7 @@ from synapse.event_auth import get_user_power_level
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import CacheMetric, register_cache from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.descriptors import lru_cache from synapse.util.caches.descriptors import lru_cache
@ -213,7 +214,7 @@ class BulkPushRuleEvaluator:
if not event.is_state(): if not event.is_state():
ignorers = await self.store.ignored_by(event.sender) ignorers = await self.store.ignored_by(event.sender)
else: else:
ignorers = set() ignorers = frozenset()
for uid, rules in rules_by_user.items(): for uid, rules in rules_by_user.items():
if event.sender == uid: if event.sender == uid:
@ -292,7 +293,7 @@ def _condition_checker(
return True return True
MemberMap = Dict[str, Tuple[str, str]] MemberMap = Dict[str, Optional[EventIdMembership]]
Rule = Dict[str, dict] Rule = Dict[str, dict]
RulesByUser = Dict[str, List[Rule]] RulesByUser = Dict[str, List[Rule]]
StateGroup = Union[object, int] StateGroup = Union[object, int]
@ -306,7 +307,7 @@ class RulesForRoomData:
*only* include data, and not references to e.g. the data stores. *only* include data, and not references to e.g. the data stores.
""" """
# event_id -> (user_id, state) # event_id -> EventIdMembership
member_map: MemberMap = attr.Factory(dict) member_map: MemberMap = attr.Factory(dict)
# user_id -> rules # user_id -> rules
rules_by_user: RulesByUser = attr.Factory(dict) rules_by_user: RulesByUser = attr.Factory(dict)
@ -447,11 +448,10 @@ class RulesForRoom:
res = self.data.member_map.get(event_id, None) res = self.data.member_map.get(event_id, None)
if res: if res:
user_id, state = res if res.membership == Membership.JOIN:
if state == Membership.JOIN: rules = self.data.rules_by_user.get(res.user_id, None)
rules = self.data.rules_by_user.get(user_id, None)
if rules: if rules:
ret_rules_by_user[user_id] = rules ret_rules_by_user[res.user_id] = rules
continue continue
# If a user has left a room we remove their push rule. If they # If a user has left a room we remove their push rule. If they
@ -502,24 +502,26 @@ class RulesForRoom:
""" """
sequence = self.data.sequence sequence = self.data.sequence
rows = await self.store.get_membership_from_event_ids(member_event_ids.values()) members = await self.store.get_membership_from_event_ids(
member_event_ids.values()
)
members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} # If the event is a join event then it will be in current state events
# If the event is a join event then it will be in current state evnts
# map but not in the DB, so we have to explicitly insert it. # map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
for event_id in member_event_ids.values(): for event_id in member_event_ids.values():
if event_id == event.event_id: if event_id == event.event_id:
members[event_id] = (event.state_key, event.membership) members[event_id] = EventIdMembership(
user_id=event.state_key, membership=event.membership
)
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values()) logger.debug("Found members %r: %r", self.room_id, members.values())
joined_user_ids = { joined_user_ids = {
user_id entry.user_id
for user_id, membership in members.values() for entry in members.values()
if membership == Membership.JOIN if entry and entry.membership == Membership.JOIN
} }
logger.debug("Joined: %r", joined_user_ids) logger.debug("Joined: %r", joined_user_ids)

View file

@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar
import bleach import bleach
import jinja2 import jinja2
from markupsafe import Markup
from synapse.api.constants import EventTypes, Membership, RoomTypes from synapse.api.constants import EventTypes, Membership, RoomTypes
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
@ -867,7 +868,7 @@ class Mailer:
) )
def safe_markup(raw_html: str) -> jinja2.Markup: def safe_markup(raw_html: str) -> Markup:
""" """
Sanitise a raw HTML string to a set of allowed tags and attributes, and linkify any bare URLs. Sanitise a raw HTML string to a set of allowed tags and attributes, and linkify any bare URLs.
@ -877,7 +878,7 @@ def safe_markup(raw_html: str) -> jinja2.Markup:
Returns: Returns:
A Markup object ready to safely use in a Jinja template. A Markup object ready to safely use in a Jinja template.
""" """
return jinja2.Markup( return Markup(
bleach.linkify( bleach.linkify(
bleach.clean( bleach.clean(
raw_html, raw_html,
@ -891,7 +892,7 @@ def safe_markup(raw_html: str) -> jinja2.Markup:
) )
def safe_text(raw_text: str) -> jinja2.Markup: def safe_text(raw_text: str) -> Markup:
""" """
Sanitise text (escape any HTML tags), and then linkify any bare URLs. Sanitise text (escape any HTML tags), and then linkify any bare URLs.
@ -901,7 +902,7 @@ def safe_text(raw_text: str) -> jinja2.Markup:
Returns: Returns:
A Markup object ready to safely use in a Jinja template. A Markup object ready to safely use in a Jinja template.
""" """
return jinja2.Markup( return Markup(
bleach.linkify(bleach.clean(raw_text, tags=[], attributes=[], strip=False)) bleach.linkify(bleach.clean(raw_text, tags=[], attributes=[], strip=False))
) )

View file

@ -74,7 +74,10 @@ REQUIREMENTS = [
# Note: 21.1.0 broke `/sync`, see #9936 # Note: 21.1.0 broke `/sync`, see #9936
"attrs>=19.2.0,!=21.1.0", "attrs>=19.2.0,!=21.1.0",
"netaddr>=0.7.18", "netaddr>=0.7.18",
"Jinja2>=2.9", # Jinja 2.x is incompatible with MarkupSafe>=2.1. To ensure that admins do not
# end up with a broken installation, with recent MarkupSafe but old Jinja, we
# add a lower bound to the Jinja2 dependency.
"Jinja2>=3.0",
"bleach>=1.4.3", "bleach>=1.4.3",
# We use `ParamSpec`, which was added in `typing-extensions` 3.10.0.0. # We use `ParamSpec`, which was added in `typing-extensions` 3.10.0.0.
"typing-extensions>=3.10.0", "typing-extensions>=3.10.0",

View file

@ -32,6 +32,7 @@ from synapse.rest.client import (
knock, knock,
login as v1_login, login as v1_login,
logout, logout,
mutual_rooms,
notifications, notifications,
openid, openid,
password_policy, password_policy,
@ -49,7 +50,6 @@ from synapse.rest.client import (
room_keys, room_keys,
room_upgrade_rest_servlet, room_upgrade_rest_servlet,
sendtodevice, sendtodevice,
shared_rooms,
sync, sync,
tags, tags,
thirdparty, thirdparty,
@ -132,4 +132,4 @@ class ClientRestResource(JsonResource):
admin.register_servlets_for_client_rest_resource(hs, client_resource) admin.register_servlets_for_client_rest_resource(hs, client_resource)
# unstable # unstable
shared_rooms.register_servlets(hs, client_resource) mutual_rooms.register_servlets(hs, client_resource)

View file

@ -28,13 +28,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UserSharedRoomsServlet(RestServlet): class UserMutualRoomsServlet(RestServlet):
""" """
GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1 GET /uk.half-shot.msc2666/user/mutual_rooms/{user_id} HTTP/1.1
""" """
PATTERNS = client_patterns( PATTERNS = client_patterns(
"/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)", "/uk.half-shot.msc2666/user/mutual_rooms/(?P<user_id>[^/]*)",
releases=(), # This is an unstable feature releases=(), # This is an unstable feature
) )
@ -42,17 +42,19 @@ class UserSharedRoomsServlet(RestServlet):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.user_directory_active = hs.config.server.update_user_directory self.user_directory_search_enabled = (
hs.config.userdirectory.user_directory_search_enabled
)
async def on_GET( async def on_GET(
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
if not self.user_directory_active: if not self.user_directory_search_enabled:
raise SynapseError( raise SynapseError(
code=400, code=400,
msg="The user directory is disabled on this server. Cannot determine shared rooms.", msg="User directory searching is disabled. Cannot determine shared rooms.",
errcode=Codes.FORBIDDEN, errcode=Codes.UNKNOWN,
) )
UserID.from_string(user_id) UserID.from_string(user_id)
@ -64,7 +66,8 @@ class UserSharedRoomsServlet(RestServlet):
msg="You cannot request a list of shared rooms with yourself", msg="You cannot request a list of shared rooms with yourself",
errcode=Codes.FORBIDDEN, errcode=Codes.FORBIDDEN,
) )
rooms = await self.store.get_shared_rooms_for_users(
rooms = await self.store.get_mutual_rooms_for_users(
requester.user.to_string(), user_id requester.user.to_string(), user_id
) )
@ -72,4 +75,4 @@ class UserSharedRoomsServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserSharedRoomsServlet(hs).register(http_server) UserMutualRoomsServlet(hs).register(http_server)

View file

@ -51,9 +51,7 @@ class RelationPaginationServlet(RestServlet):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.clock = hs.get_clock() self._relations_handler = hs.get_relations_handler()
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
async def on_GET( async def on_GET(
self, self,
@ -65,16 +63,6 @@ class RelationPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
room_id, requester.user.to_string(), allow_departed_users=True
)
# This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
limit = parse_integer(request, "limit", default=5) limit = parse_integer(request, "limit", default=5)
direction = parse_string( direction = parse_string(
request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"] request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"]
@ -90,9 +78,9 @@ class RelationPaginationServlet(RestServlet):
if to_token_str: if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str) to_token = await StreamToken.from_string(self.store, to_token_str)
pagination_chunk = await self.store.get_relations_for_event( result = await self._relations_handler.get_relations(
requester=requester,
event_id=parent_id, event_id=parent_id,
event=event,
room_id=room_id, room_id=room_id,
relation_type=relation_type, relation_type=relation_type,
event_type=event_type, event_type=event_type,
@ -102,30 +90,7 @@ class RelationPaginationServlet(RestServlet):
to_token=to_token, to_token=to_token,
) )
events = await self.store.get_events_as_list( return 200, result
[c["event_id"] for c in pagination_chunk.chunk]
)
now = self.clock.time_msec()
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
original_event = self._event_serializer.serialize_event(
event, now, bundle_aggregations=None
)
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self.store.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)
return_value = await pagination_chunk.to_dict(self.store)
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event
return 200, return_value
class RelationAggregationPaginationServlet(RestServlet): class RelationAggregationPaginationServlet(RestServlet):
@ -245,9 +210,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.clock = hs.get_clock() self._relations_handler = hs.get_relations_handler()
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
async def on_GET( async def on_GET(
self, self,
@ -260,18 +223,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
room_id,
requester.user.to_string(),
allow_departed_users=True,
)
# This checks that a) the event exists and b) the user is allowed to
# view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
if relation_type != RelationTypes.ANNOTATION: if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'") raise SynapseError(400, "Relation type must be 'annotation'")
@ -286,9 +237,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
if to_token_str: if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str) to_token = await StreamToken.from_string(self.store, to_token_str)
result = await self.store.get_relations_for_event( result = await self._relations_handler.get_relations(
requester=requester,
event_id=parent_id, event_id=parent_id,
event=event,
room_id=room_id, room_id=room_id,
relation_type=relation_type, relation_type=relation_type,
event_type=event_type, event_type=event_type,
@ -298,17 +249,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
to_token=to_token, to_token=to_token,
) )
events = await self.store.get_events_as_list( return 200, result
[c["event_id"] for c in result.chunk]
)
now = self.clock.time_msec()
serialized_events = self._event_serializer.serialize_events(events, now)
return_value = await result.to_dict(self.store)
return_value["chunk"] = serialized_events
return 200, return_value
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:

View file

@ -649,6 +649,7 @@ class RoomEventServlet(RestServlet):
self._store = hs.get_datastores().main self._store = hs.get_datastores().main
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET( async def on_GET(
@ -667,7 +668,7 @@ class RoomEventServlet(RestServlet):
if event: if event:
# Ensure there are bundled aggregations available. # Ensure there are bundled aggregations available.
aggregations = await self._store.get_bundled_aggregations( aggregations = await self._relations_handler.get_bundled_aggregations(
[event], requester.user.to_string() [event], requester.user.to_string()
) )

View file

@ -124,14 +124,14 @@ class RoomBatchSendEventRestServlet(RestServlet):
) )
# For the event we are inserting next to (`prev_event_ids_from_query`), # For the event we are inserting next to (`prev_event_ids_from_query`),
# find the most recent auth events (derived from state events) that # find the most recent state events that allowed that message to be
# allowed that message to be sent. We will use that as a base # sent. We will use that as a base to auth our historical messages
# to auth our historical messages against. # against.
auth_event_ids = await self.room_batch_handler.get_most_recent_auth_event_ids_from_event_id_list( state_event_ids = await self.room_batch_handler.get_most_recent_full_state_ids_from_event_id_list(
prev_event_ids_from_query prev_event_ids_from_query
) )
if not auth_event_ids: if not state_event_ids:
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
"No auth events found for given prev_event query parameter. The prev_event=%s probably does not exist." "No auth events found for given prev_event query parameter. The prev_event=%s probably does not exist."
@ -148,13 +148,13 @@ class RoomBatchSendEventRestServlet(RestServlet):
await self.room_batch_handler.persist_state_events_at_start( await self.room_batch_handler.persist_state_events_at_start(
state_events_at_start=body["state_events_at_start"], state_events_at_start=body["state_events_at_start"],
room_id=room_id, room_id=room_id,
initial_auth_event_ids=auth_event_ids, initial_state_event_ids=state_event_ids,
app_service_requester=requester, app_service_requester=requester,
) )
) )
# Update our ongoing auth event ID list with all of the new state we # Update our ongoing auth event ID list with all of the new state we
# just created # just created
auth_event_ids.extend(state_event_ids_at_start) state_event_ids.extend(state_event_ids_at_start)
inherited_depth = await self.room_batch_handler.inherit_depth_from_prev_ids( inherited_depth = await self.room_batch_handler.inherit_depth_from_prev_ids(
prev_event_ids_from_query prev_event_ids_from_query
@ -196,7 +196,12 @@ class RoomBatchSendEventRestServlet(RestServlet):
), ),
base_insertion_event_dict, base_insertion_event_dict,
prev_event_ids=base_insertion_event_dict.get("prev_events"), prev_event_ids=base_insertion_event_dict.get("prev_events"),
auth_event_ids=auth_event_ids, # Also set the explicit state here because we want to resolve
# any `state_events_at_start` here too. It's not strictly
# necessary to accomplish anything but if someone asks for the
# state at this point, we probably want to show them the
# historical state that was part of this batch.
state_event_ids=state_event_ids,
historical=True, historical=True,
depth=inherited_depth, depth=inherited_depth,
) )
@ -212,7 +217,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
room_id=room_id, room_id=room_id,
batch_id_to_connect_to=batch_id_to_connect_to, batch_id_to_connect_to=batch_id_to_connect_to,
inherited_depth=inherited_depth, inherited_depth=inherited_depth,
auth_event_ids=auth_event_ids, initial_state_event_ids=state_event_ids,
app_service_requester=requester, app_service_requester=requester,
) )

View file

@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.types import JsonDict from synapse.types import JsonMapping
from ._base import client_patterns from ._base import client_patterns
@ -38,7 +38,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler() self.user_directory_handler = hs.get_user_directory_handler()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonMapping]:
"""Searches for users in directory """Searches for users in directory
Returns: Returns:

View file

@ -16,7 +16,6 @@ import itertools
import logging import logging
import re import re
from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
from urllib import parse as urlparse
if TYPE_CHECKING: if TYPE_CHECKING:
from lxml import etree from lxml import etree
@ -144,9 +143,7 @@ def decode_body(
return etree.fromstring(body, parser) return etree.fromstring(body, parser)
def parse_html_to_open_graph( def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
tree: "etree.Element", media_uri: str
) -> Dict[str, Optional[str]]:
""" """
Parse the HTML document into an Open Graph response. Parse the HTML document into an Open Graph response.
@ -155,7 +152,6 @@ def parse_html_to_open_graph(
Args: Args:
tree: The parsed HTML document. tree: The parsed HTML document.
media_url: The URI used to download the body.
Returns: Returns:
The Open Graph response as a dictionary. The Open Graph response as a dictionary.
@ -209,7 +205,7 @@ def parse_html_to_open_graph(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
) )
if meta_image: if meta_image:
og["og:image"] = rebase_url(meta_image[0], media_uri) og["og:image"] = meta_image[0]
else: else:
# TODO: consider inlined CSS styles as well as width & height attribs # TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
@ -320,37 +316,6 @@ def _iterate_over_text(
) )
def rebase_url(url: str, base: str) -> str:
"""
Resolves a potentially relative `url` against an absolute `base` URL.
For example:
>>> rebase_url("subpage", "https://example.com/foo/")
'https://example.com/foo/subpage'
>>> rebase_url("sibling", "https://example.com/foo")
'https://example.com/sibling'
>>> rebase_url("/bar", "https://example.com/foo/")
'https://example.com/bar'
>>> rebase_url("https://alice.com/a/", "https://example.com/foo/")
'https://alice.com/a'
"""
base_parts = urlparse.urlparse(base)
# Convert the parsed URL to a list for (potential) modification.
url_parts = list(urlparse.urlparse(url))
# Add a scheme, if one does not exist.
if not url_parts[0]:
url_parts[0] = base_parts.scheme or "http"
# Fix up the hostname, if this is not a data URL.
if url_parts[0] != "data" and not url_parts[1]:
url_parts[1] = base_parts.netloc
# If the path does not start with a /, nest it under the base path's last
# directory.
if not url_parts[2].startswith("/"):
url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2]
return urlparse.urlunparse(url_parts)
def summarize_paragraphs( def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500 text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]: ) -> Optional[str]:

View file

@ -22,7 +22,7 @@ import shutil
import sys import sys
import traceback import traceback
from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple
from urllib import parse as urlparse from urllib.parse import urljoin, urlparse, urlsplit
from urllib.request import urlopen from urllib.request import urlopen
import attr import attr
@ -44,11 +44,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.oembed import OEmbedProvider from synapse.rest.media.v1.oembed import OEmbedProvider
from synapse.rest.media.v1.preview_html import ( from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph
decode_body,
parse_html_to_open_graph,
rebase_url,
)
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
@ -187,7 +183,7 @@ class PreviewUrlResource(DirectServeJsonResource):
ts = self.clock.time_msec() ts = self.clock.time_msec()
# XXX: we could move this into _do_preview if we wanted. # XXX: we could move this into _do_preview if we wanted.
url_tuple = urlparse.urlsplit(url) url_tuple = urlsplit(url)
for entry in self.url_preview_url_blacklist: for entry in self.url_preview_url_blacklist:
match = True match = True
for attrib in entry: for attrib in entry:
@ -322,7 +318,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# Parse Open Graph information from the HTML in case the oEmbed # Parse Open Graph information from the HTML in case the oEmbed
# response failed or is incomplete. # response failed or is incomplete.
og_from_html = parse_html_to_open_graph(tree, media_info.uri) og_from_html = parse_html_to_open_graph(tree)
# Compile the Open Graph response by using the scraped # Compile the Open Graph response by using the scraped
# information from the HTML and overlaying any information # information from the HTML and overlaying any information
@ -588,12 +584,17 @@ class PreviewUrlResource(DirectServeJsonResource):
if "og:image" not in og or not og["og:image"]: if "og:image" not in og or not og["og:image"]:
return return
# The image URL from the HTML might be relative to the previewed page,
# convert it to an URL which can be requested directly.
image_url = og["og:image"]
url_parts = urlparse(image_url)
if url_parts.scheme != "data":
image_url = urljoin(media_info.uri, image_url)
# FIXME: it might be cleaner to use the same flow as the main /preview_url # FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we # request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up. # just rely on the caching on the master request to speed things up.
image_info = await self._handle_url( image_info = await self._handle_url(image_url, user, allow_data_urls=True)
rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True
)
if _is_media(image_info.media_type): if _is_media(image_info.media_type):
# TODO: make sure we don't choke on white-on-transparent images # TODO: make sure we don't choke on white-on-transparent images

View file

@ -94,6 +94,7 @@ from synapse.handlers.profile import ProfileHandler
from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.handlers.relations import RelationsHandler
from synapse.handlers.room import ( from synapse.handlers.room import (
RoomContextHandler, RoomContextHandler,
RoomCreationHandler, RoomCreationHandler,
@ -719,6 +720,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_pagination_handler(self) -> PaginationHandler: def get_pagination_handler(self) -> PaginationHandler:
return PaginationHandler(self) return PaginationHandler(self)
@cache_in_self
def get_relations_handler(self) -> RelationsHandler:
return RelationsHandler(self)
@cache_in_self @cache_in_self
def get_room_context_handler(self) -> RoomContextHandler: def get_room_context_handler(self) -> RoomContextHandler:
return RoomContextHandler(self) return RoomContextHandler(self)

View file

@ -41,6 +41,7 @@ from prometheus_client import Histogram
from typing_extensions import Literal from typing_extensions import Literal
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig from synapse.config.database import DatabaseConnectionConfig
@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor from synapse.storage.types import Connection, Cursor
from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
if TYPE_CHECKING: if TYPE_CHECKING:
@ -286,13 +288,17 @@ class LoggingTransaction:
""" """
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore from psycopg2.extras import execute_batch
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) self._do_execute(
lambda the_sql: execute_batch(self.txn, the_sql, args), sql
)
else: else:
self.executemany(sql, args) self.executemany(sql, args)
def execute_values(self, sql: str, *args: Any, fetch: bool = True) -> List[Tuple]: def execute_values(
self, sql: str, values: Iterable[Iterable[Any]], fetch: bool = True
) -> List[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when """Corresponds to psycopg2.extras.execute_values. Only available when
using postgres. using postgres.
@ -300,10 +306,11 @@ class LoggingTransaction:
rows (e.g. INSERTs). rows (e.g. INSERTs).
""" """
assert isinstance(self.database_engine, PostgresEngine) assert isinstance(self.database_engine, PostgresEngine)
from psycopg2.extras import execute_values # type: ignore from psycopg2.extras import execute_values
return self._do_execute( return self._do_execute(
lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args lambda the_sql: execute_values(self.txn, the_sql, values, fetch=fetch),
sql,
) )
def execute(self, sql: str, *args: Any) -> None: def execute(self, sql: str, *args: Any) -> None:
@ -732,6 +739,8 @@ class DatabasePool:
Returns: Returns:
The result of func The result of func
""" """
async def _runInteraction() -> R:
after_callbacks: List[_CallbackListEntry] = [] after_callbacks: List[_CallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = [] exception_callbacks: List[_CallbackListEntry] = []
@ -754,12 +763,21 @@ class DatabasePool:
for after_callback, after_args, after_kwargs in after_callbacks: for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs) after_callback(*after_args, **after_kwargs)
return cast(R, result)
except Exception: except Exception:
for after_callback, after_args, after_kwargs in exception_callbacks: for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs) after_callback(*after_args, **after_kwargs)
raise raise
return cast(R, result) # To handle cancellation, we ensure that `after_callback`s and
# `exception_callback`s are always run, since the transaction will complete
# on another thread regardless of cancellation.
#
# We also wait until everything above is done before releasing the
# `CancelledError`, so that logging contexts won't get used after they have been
# finished.
return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
async def runWithConnection( async def runWithConnection(
self, self,

View file

@ -14,7 +14,17 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast from typing import (
TYPE_CHECKING,
Any,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Tuple,
cast,
)
from synapse.api.constants import AccountDataTypes from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@ -365,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
) )
@cached(max_entries=5000, iterable=True) @cached(max_entries=5000, iterable=True)
async def ignored_by(self, user_id: str) -> Set[str]: async def ignored_by(self, user_id: str) -> FrozenSet[str]:
""" """
Get users which ignore the given user. Get users which ignore the given user.
@ -375,7 +385,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
Return: Return:
The user IDs which ignore the given user. The user IDs which ignore the given user.
""" """
return set( return frozenset(
await self.db_pool.simple_select_onecol( await self.db_pool.simple_select_onecol(
table="ignored_users", table="ignored_users",
keyvalues={"ignored_user_id": user_id}, keyvalues={"ignored_user_id": user_id},
@ -384,6 +394,26 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
) )
) )
@cached(max_entries=5000, iterable=True)
async def ignored_users(self, user_id: str) -> FrozenSet[str]:
"""
Get users which the given user ignores.
Params:
user_id: The user ID which is making the request.
Return:
The user IDs which are ignored by the given user.
"""
return frozenset(
await self.db_pool.simple_select_onecol(
table="ignored_users",
keyvalues={"ignorer_user_id": user_id},
retcol="ignored_user_id",
desc="ignored_users",
)
)
def process_replication_rows( def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
@ -529,6 +559,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
else: else:
currently_ignored_users = set() currently_ignored_users = set()
# If the data has not changed, nothing to do.
if previously_ignored_users == currently_ignored_users:
return
# Delete entries which are no longer ignored. # Delete entries which are no longer ignored.
self.db_pool.simple_delete_many_txn( self.db_pool.simple_delete_many_txn(
txn, txn,
@ -551,6 +585,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# Invalidate the cache for any ignored users which were added or removed. # Invalidate the cache for any ignored users which were added or removed.
for ignored_user_id in previously_ignored_users ^ currently_ignored_users: for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
async def purge_account_data_for_user(self, user_id: str) -> None: async def purge_account_data_for_user(self, user_id: str) -> None:
""" """

View file

@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
EventsStream, EventsStream,
EventsStreamCurrentStateRow, EventsStreamCurrentStateRow,
EventsStreamEventRow, EventsStreamEventRow,
EventsStreamRow,
) )
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import ( from synapse.storage.database import (
@ -31,6 +32,7 @@ from synapse.storage.database import (
LoggingTransaction, LoggingTransaction,
) )
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
if TYPE_CHECKING: if TYPE_CHECKING:
@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if last_id == current_id: if last_id == current_id:
return [], current_id, False return [], current_id, False
def get_all_updated_caches_txn(txn): def get_all_updated_caches_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# We purposefully don't bound by the current token, as we want to # We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache # send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine. # invalidations are idempotent, so duplicates are fine.
@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"get_all_updated_caches", get_all_updated_caches_txn "get_all_updated_caches", get_all_updated_caches_txn
) )
def process_replication_rows(self, stream_name, instance_name, token, rows): def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == EventsStream.NAME: if stream_name == EventsStream.NAME:
for row in rows: for row in rows:
self._process_event_stream_row(token, row) self._process_event_stream_row(token, row)
@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows) super().process_replication_rows(stream_name, instance_name, token, rows)
def _process_event_stream_row(self, token, row): def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data data = row.data
if row.type == EventsStreamEventRow.TypeId: if row.type == EventsStreamEventRow.TypeId:
assert isinstance(data, EventsStreamEventRow)
self._invalidate_caches_for_event( self._invalidate_caches_for_event(
token, token,
data.event_id, data.event_id,
@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
backfilled=False, backfilled=False,
) )
elif row.type == EventsStreamCurrentStateRow.TypeId: elif row.type == EventsStreamCurrentStateRow.TypeId:
self._curr_state_delta_stream_cache.entity_has_changed( assert isinstance(data, EventsStreamCurrentStateRow)
row.data.room_id, token self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
)
if data.type == EventTypes.Member: if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate( self.get_rooms_for_user_with_stream_ordering.invalidate(
@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
def _invalidate_caches_for_event( def _invalidate_caches_for_event(
self, self,
stream_ordering, stream_ordering: int,
event_id, event_id: str,
room_id, room_id: str,
etype, etype: str,
state_key, state_key: Optional[str],
redacts, redacts: Optional[str],
relates_to, relates_to: Optional[str],
backfilled, backfilled: bool,
): ) -> None:
self._invalidate_get_event_cache(event_id) self._invalidate_get_event_cache(event_id)
self.have_seen_event.invalidate((room_id, event_id)) self.have_seen_event.invalidate((room_id, event_id))
@ -186,6 +192,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,)) self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.
self._get_membership_from_event_id.invalidate((event_id,))
if not backfilled: if not backfilled:
self._events_stream_cache.entity_has_changed(room_id, stream_ordering) self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
@ -207,7 +217,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_thread_summary.invalidate((relates_to,)) self.get_thread_summary.invalidate((relates_to,))
self.get_thread_participated.invalidate((relates_to,)) self.get_thread_participated.invalidate((relates_to,))
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): async def invalidate_cache_and_stream(
self, cache_name: str, keys: Tuple[Any, ...]
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves """Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches. will know to invalidate their caches.
@ -227,7 +239,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
keys, keys,
) )
def _invalidate_cache_and_stream(self, txn, cache_func, keys): def _invalidate_cache_and_stream(
self,
txn: LoggingTransaction,
cache_func: _CachedFunction,
keys: Tuple[Any, ...],
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves """Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches. will know to invalidate their caches.
@ -238,7 +255,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate, keys) txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys) self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
def _invalidate_all_cache_and_stream(self, txn, cache_func): def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: _CachedFunction
) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves """Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches. will know to invalidate their caches.
""" """
@ -279,8 +298,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
) )
def _send_invalidation_to_replication( def _send_invalidation_to_replication(
self, txn, cache_name: str, keys: Optional[Iterable[Any]] self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
): ) -> None:
"""Notifies replication that given cache has been invalidated. """Notifies replication that given cache has been invalidated.
Note that this does *not* invalidate the cache locally. Note that this does *not* invalidate the cache locally.
@ -315,7 +334,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"instance_name": self._instance_name, "instance_name": self._instance_name,
"cache_func": cache_name, "cache_func": cache_name,
"keys": keys, "keys": keys,
"invalidation_ts": self.clock.time_msec(), "invalidation_ts": self._clock.time_msec(),
}, },
) )

View file

@ -1073,9 +1073,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
/* Get the depth and stream_ordering of the prev_event_id from the events table */ /* Get the depth and stream_ordering of the prev_event_id from the events table */
INNER JOIN events INNER JOIN events
ON prev_event_id = events.event_id ON prev_event_id = events.event_id
/* exclude outliers from the results (we don't have the state, so cannot
* verify if the requesting server can see them).
*/
WHERE NOT events.outlier
/* Look for an edge which matches the given event_id */ /* Look for an edge which matches the given event_id */
WHERE event_edges.event_id = ? AND event_edges.event_id = ? AND NOT event_edges.is_state
AND event_edges.is_state = ?
/* Because we can have many events at the same depth, /* Because we can have many events at the same depth,
* we want to also tie-break and sort on stream_ordering */ * we want to also tie-break and sort on stream_ordering */
ORDER BY depth DESC, stream_ordering DESC ORDER BY depth DESC, stream_ordering DESC
@ -1084,7 +1090,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute( txn.execute(
connected_prev_event_query, connected_prev_event_query,
(event_id, False, limit), (event_id, limit),
) )
return [ return [
BackfillQueueNavigationItem( BackfillQueueNavigationItem(

View file

@ -1745,6 +1745,13 @@ class PersistEventsStore:
(event.state_key,), (event.state_key,),
) )
# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.
txn.call_after(
self.store._get_membership_from_event_id.invalidate,
(event.event_id,),
)
# We update the local_current_membership table only if the event is # We update the local_current_membership table only if the event is
# "current", i.e., its something that has just happened. # "current", i.e., its something that has just happened.
# #

View file

@ -13,13 +13,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from typing_extensions import TypedDict from typing_extensions import TypedDict
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
@ -75,7 +79,7 @@ class GroupServerWorkerStore(SQLBaseStore):
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
# TODO: Pagination # TODO: Pagination
keyvalues = {"group_id": group_id} keyvalues: JsonDict = {"group_id": group_id}
if not include_private: if not include_private:
keyvalues["is_public"] = True keyvalues["is_public"] = True
@ -117,7 +121,7 @@ class GroupServerWorkerStore(SQLBaseStore):
# TODO: Pagination # TODO: Pagination
def _get_rooms_in_group_txn(txn): def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]:
sql = """ sql = """
SELECT room_id, is_public FROM group_rooms SELECT room_id, is_public FROM group_rooms
WHERE group_id = ? WHERE group_id = ?
@ -176,8 +180,10 @@ class GroupServerWorkerStore(SQLBaseStore):
* "order": int, the sort order of rooms in this category * "order": int, the sort order of rooms in this category
""" """
def _get_rooms_for_summary_txn(txn): def _get_rooms_for_summary_txn(
keyvalues = {"group_id": group_id} txn: LoggingTransaction,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
keyvalues: JsonDict = {"group_id": group_id}
if not include_private: if not include_private:
keyvalues["is_public"] = True keyvalues["is_public"] = True
@ -241,7 +247,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_rooms_for_summary", _get_rooms_for_summary_txn "get_rooms_for_summary", _get_rooms_for_summary_txn
) )
async def get_group_categories(self, group_id): async def get_group_categories(self, group_id: str) -> JsonDict:
rows = await self.db_pool.simple_select_list( rows = await self.db_pool.simple_select_list(
table="group_room_categories", table="group_room_categories",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
@ -257,7 +263,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows for row in rows
} }
async def get_group_category(self, group_id, category_id): async def get_group_category(self, group_id: str, category_id: str) -> JsonDict:
category = await self.db_pool.simple_select_one( category = await self.db_pool.simple_select_one(
table="group_room_categories", table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id}, keyvalues={"group_id": group_id, "category_id": category_id},
@ -269,7 +275,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return category return category
async def get_group_roles(self, group_id): async def get_group_roles(self, group_id: str) -> JsonDict:
rows = await self.db_pool.simple_select_list( rows = await self.db_pool.simple_select_list(
table="group_roles", table="group_roles",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
@ -285,7 +291,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows for row in rows
} }
async def get_group_role(self, group_id, role_id): async def get_group_role(self, group_id: str, role_id: str) -> JsonDict:
role = await self.db_pool.simple_select_one( role = await self.db_pool.simple_select_one(
table="group_roles", table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id}, keyvalues={"group_id": group_id, "role_id": role_id},
@ -311,15 +317,19 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_local_groups_for_room", desc="get_local_groups_for_room",
) )
async def get_users_for_summary_by_role(self, group_id, include_private=False): async def get_users_for_summary_by_role(
self, group_id: str, include_private: bool = False
) -> Tuple[List[JsonDict], JsonDict]:
"""Get the users and roles that should be included in a summary request """Get the users and roles that should be included in a summary request
Returns: Returns:
([users], [roles]) ([users], [roles])
""" """
def _get_users_for_summary_txn(txn): def _get_users_for_summary_txn(
keyvalues = {"group_id": group_id} txn: LoggingTransaction,
) -> Tuple[List[JsonDict], JsonDict]:
keyvalues: JsonDict = {"group_id": group_id}
if not include_private: if not include_private:
keyvalues["is_public"] = True keyvalues["is_public"] = True
@ -406,7 +416,9 @@ class GroupServerWorkerStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
async def get_users_membership_info_in_group(self, group_id, user_id): async def get_users_membership_info_in_group(
self, group_id: str, user_id: str
) -> JsonDict:
"""Get a dict describing the membership of a user in a group. """Get a dict describing the membership of a user in a group.
Example if joined: Example if joined:
@ -421,7 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore):
An empty dict if the user is not join/invite/etc An empty dict if the user is not join/invite/etc
""" """
def _get_users_membership_in_group_txn(txn): def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict:
row = self.db_pool.simple_select_one_txn( row = self.db_pool.simple_select_one_txn(
txn, txn,
table="group_users", table="group_users",
@ -463,10 +475,14 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_publicised_groups_for_user", desc="get_publicised_groups_for_user",
) )
async def get_attestations_need_renewals(self, valid_until_ms): async def get_attestations_need_renewals(
self, valid_until_ms: int
) -> List[Dict[str, Any]]:
"""Get all attestations that need to be renewed until givent time""" """Get all attestations that need to be renewed until givent time"""
def _get_attestations_need_renewals_txn(txn): def _get_attestations_need_renewals_txn(
txn: LoggingTransaction,
) -> List[Dict[str, Any]]:
sql = """ sql = """
SELECT group_id, user_id FROM group_attestations_renewals SELECT group_id, user_id FROM group_attestations_renewals
WHERE valid_until_ms <= ? WHERE valid_until_ms <= ?
@ -478,7 +494,9 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_attestations_need_renewals", _get_attestations_need_renewals_txn "get_attestations_need_renewals", _get_attestations_need_renewals_txn
) )
async def get_remote_attestation(self, group_id, user_id): async def get_remote_attestation(
self, group_id: str, user_id: str
) -> Optional[JsonDict]:
"""Get the attestation that proves the remote agrees that the user is """Get the attestation that proves the remote agrees that the user is
in the group. in the group.
""" """
@ -504,8 +522,8 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_joined_groups", desc="get_joined_groups",
) )
async def get_all_groups_for_user(self, user_id, now_token): async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]:
def _get_all_groups_for_user_txn(txn): def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """ sql = """
SELECT group_id, type, membership, u.content SELECT group_id, type, membership, u.content
FROM local_group_updates AS u FROM local_group_updates AS u
@ -528,15 +546,16 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_all_groups_for_user", _get_all_groups_for_user_txn "get_all_groups_for_user", _get_all_groups_for_user_txn
) )
async def get_groups_changes_for_user(self, user_id, from_token, to_token): async def get_groups_changes_for_user(
from_token = int(from_token) self, user_id: str, from_token: int, to_token: int
has_changed = self._group_updates_stream_cache.has_entity_changed( ) -> List[JsonDict]:
has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined]
user_id, from_token user_id, from_token
) )
if not has_changed: if not has_changed:
return [] return []
def _get_groups_changes_for_user_txn(txn): def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """ sql = """
SELECT group_id, membership, type, u.content SELECT group_id, membership, type, u.content
FROM local_group_updates AS u FROM local_group_updates AS u
@ -583,12 +602,14 @@ class GroupServerWorkerStore(SQLBaseStore):
""" """
last_id = int(last_id) last_id = int(last_id)
has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined]
if not has_changed: if not has_changed:
return [], current_id, False return [], current_id, False
def _get_all_groups_changes_txn(txn): def _get_all_groups_changes_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
sql = """ sql = """
SELECT stream_id, group_id, user_id, type, content SELECT stream_id, group_id, user_id, type, content
FROM local_group_updates FROM local_group_updates
@ -596,10 +617,13 @@ class GroupServerWorkerStore(SQLBaseStore):
LIMIT ? LIMIT ?
""" """
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
updates = [ updates = cast(
List[Tuple[int, tuple]],
[
(stream_id, (group_id, user_id, gtype, db_to_json(content_json))) (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
for stream_id, group_id, user_id, gtype, content_json in txn for stream_id, group_id, user_id, gtype, content_json in txn
] ],
)
limited = False limited = False
upto_token = current_id upto_token = current_id
@ -633,8 +657,8 @@ class GroupServerStore(GroupServerWorkerStore):
self, self,
group_id: str, group_id: str,
room_id: str, room_id: str,
category_id: str, category_id: Optional[str],
order: int, order: Optional[int],
is_public: Optional[bool], is_public: Optional[bool],
) -> None: ) -> None:
"""Add (or update) room's entry in summary. """Add (or update) room's entry in summary.
@ -661,11 +685,11 @@ class GroupServerStore(GroupServerWorkerStore):
def _add_room_to_summary_txn( def _add_room_to_summary_txn(
self, self,
txn, txn: LoggingTransaction,
group_id: str, group_id: str,
room_id: str, room_id: str,
category_id: str, category_id: Optional[str],
order: int, order: Optional[int],
is_public: Optional[bool], is_public: Optional[bool],
) -> None: ) -> None:
"""Add (or update) room's entry in summary. """Add (or update) room's entry in summary.
@ -750,7 +774,7 @@ class GroupServerStore(GroupServerWorkerStore):
WHERE group_id = ? AND category_id = ? WHERE group_id = ? AND category_id = ?
""" """
txn.execute(sql, (group_id, category_id)) txn.execute(sql, (group_id, category_id))
(order,) = txn.fetchone() (order,) = cast(Tuple[int], txn.fetchone())
if existing: if existing:
to_update = {} to_update = {}
@ -766,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore):
"category_id": category_id, "category_id": category_id,
"room_id": room_id, "room_id": room_id,
}, },
values=to_update, updatevalues=to_update,
) )
else: else:
if is_public is None: if is_public is None:
@ -785,7 +809,7 @@ class GroupServerStore(GroupServerWorkerStore):
) )
async def remove_room_from_summary( async def remove_room_from_summary(
self, group_id: str, room_id: str, category_id: str self, group_id: str, room_id: str, category_id: Optional[str]
) -> int: ) -> int:
if category_id is None: if category_id is None:
category_id = _DEFAULT_CATEGORY_ID category_id = _DEFAULT_CATEGORY_ID
@ -808,8 +832,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_public: Optional[bool], is_public: Optional[bool],
) -> None: ) -> None:
"""Add/update room category for group""" """Add/update room category for group"""
insertion_values = {} insertion_values: JsonDict = {}
update_values = {"category_id": category_id} # This cannot be empty update_values: JsonDict = {"category_id": category_id} # This cannot be empty
if profile is None: if profile is None:
insertion_values["profile"] = "{}" insertion_values["profile"] = "{}"
@ -844,8 +868,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_public: Optional[bool], is_public: Optional[bool],
) -> None: ) -> None:
"""Add/remove user role""" """Add/remove user role"""
insertion_values = {} insertion_values: JsonDict = {}
update_values = {"role_id": role_id} # This cannot be empty update_values: JsonDict = {"role_id": role_id} # This cannot be empty
if profile is None: if profile is None:
insertion_values["profile"] = "{}" insertion_values["profile"] = "{}"
@ -876,8 +900,8 @@ class GroupServerStore(GroupServerWorkerStore):
self, self,
group_id: str, group_id: str,
user_id: str, user_id: str,
role_id: str, role_id: Optional[str],
order: int, order: Optional[int],
is_public: Optional[bool], is_public: Optional[bool],
) -> None: ) -> None:
"""Add (or update) user's entry in summary. """Add (or update) user's entry in summary.
@ -904,13 +928,13 @@ class GroupServerStore(GroupServerWorkerStore):
def _add_user_to_summary_txn( def _add_user_to_summary_txn(
self, self,
txn, txn: LoggingTransaction,
group_id: str, group_id: str,
user_id: str, user_id: str,
role_id: str, role_id: Optional[str],
order: int, order: Optional[int],
is_public: Optional[bool], is_public: Optional[bool],
): ) -> None:
"""Add (or update) user's entry in summary. """Add (or update) user's entry in summary.
Args: Args:
@ -989,7 +1013,7 @@ class GroupServerStore(GroupServerWorkerStore):
WHERE group_id = ? AND role_id = ? WHERE group_id = ? AND role_id = ?
""" """
txn.execute(sql, (group_id, role_id)) txn.execute(sql, (group_id, role_id))
(order,) = txn.fetchone() (order,) = cast(Tuple[int], txn.fetchone())
if existing: if existing:
to_update = {} to_update = {}
@ -1005,7 +1029,7 @@ class GroupServerStore(GroupServerWorkerStore):
"role_id": role_id, "role_id": role_id,
"user_id": user_id, "user_id": user_id,
}, },
values=to_update, updatevalues=to_update,
) )
else: else:
if is_public is None: if is_public is None:
@ -1024,7 +1048,7 @@ class GroupServerStore(GroupServerWorkerStore):
) )
async def remove_user_from_summary( async def remove_user_from_summary(
self, group_id: str, user_id: str, role_id: str self, group_id: str, user_id: str, role_id: Optional[str]
) -> int: ) -> int:
if role_id is None: if role_id is None:
role_id = _DEFAULT_ROLE_ID role_id = _DEFAULT_ROLE_ID
@ -1065,7 +1089,7 @@ class GroupServerStore(GroupServerWorkerStore):
Optional if the user and group are on the same server Optional if the user and group are on the same server
""" """
def _add_user_to_group_txn(txn): def _add_user_to_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="group_users", table="group_users",
@ -1108,7 +1132,7 @@ class GroupServerStore(GroupServerWorkerStore):
await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
async def remove_user_from_group(self, group_id: str, user_id: str) -> None: async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
def _remove_user_from_group_txn(txn): def _remove_user_from_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="group_users", table="group_users",
@ -1159,7 +1183,7 @@ class GroupServerStore(GroupServerWorkerStore):
) )
async def remove_room_from_group(self, group_id: str, room_id: str) -> None: async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
def _remove_room_from_group_txn(txn): def _remove_room_from_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="group_rooms", table="group_rooms",
@ -1216,7 +1240,9 @@ class GroupServerStore(GroupServerWorkerStore):
content = content or {} content = content or {}
def _register_user_group_membership_txn(txn, next_id): def _register_user_group_membership_txn(
txn: LoggingTransaction, next_id: int
) -> int:
# TODO: Upsert? # TODO: Upsert?
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
@ -1249,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore):
), ),
}, },
) )
self._group_updates_stream_cache.entity_has_changed(user_id, next_id) self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined]
# TODO: Insert profile to ensure it comes down stream if its a join. # TODO: Insert profile to ensure it comes down stream if its a join.
@ -1289,7 +1315,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id return next_id
async with self._group_updates_id_gen.get_next() as next_id: async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined]
res = await self.db_pool.runInteraction( res = await self.db_pool.runInteraction(
"register_user_group_membership", "register_user_group_membership",
_register_user_group_membership_txn, _register_user_group_membership_txn,
@ -1298,7 +1324,13 @@ class GroupServerStore(GroupServerWorkerStore):
return res return res
async def create_group( async def create_group(
self, group_id, user_id, name, avatar_url, short_description, long_description self,
group_id: str,
user_id: str,
name: str,
avatar_url: str,
short_description: str,
long_description: str,
) -> None: ) -> None:
await self.db_pool.simple_insert( await self.db_pool.simple_insert(
table="groups", table="groups",
@ -1313,7 +1345,7 @@ class GroupServerStore(GroupServerWorkerStore):
desc="create_group", desc="create_group",
) )
async def update_group_profile(self, group_id, profile): async def update_group_profile(self, group_id: str, profile: JsonDict) -> None:
await self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
table="groups", table="groups",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
@ -1361,8 +1393,8 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_attestation_renewal", desc="remove_attestation_renewal",
) )
def get_group_stream_token(self): def get_group_stream_token(self) -> int:
return self._group_updates_id_gen.get_current_token() return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined]
async def delete_group(self, group_id: str) -> None: async def delete_group(self, group_id: str) -> None:
"""Deletes a group fully from the database. """Deletes a group fully from the database.
@ -1371,7 +1403,7 @@ class GroupServerStore(GroupServerWorkerStore):
group_id: The group ID to delete. group_id: The group ID to delete.
""" """
def _delete_group_txn(txn): def _delete_group_txn(txn: LoggingTransaction) -> None:
tables = [ tables = [
"groups", "groups",
"group_users", "group_users",

View file

@ -156,7 +156,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
hs: "HomeServer", hs: "HomeServer",
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name = hs.hostname self.server_name: str = hs.hostname
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media """Get the metadata for a local piece of media

View file

@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Dict, List, Optional from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
LoggingDatabaseConnection, LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause, make_in_list_sql_clause,
) )
from synapse.storage.databases.main.registration import RegistrationWorkerStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.threepids import canonicalise_email from synapse.util.threepids import canonicalise_email
@ -56,7 +58,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
Number of current monthly active users Number of current monthly active users
""" """
def _count_users(txn): def _count_users(txn: LoggingTransaction) -> int:
# Exclude app service users # Exclude app service users
sql = """ sql = """
SELECT COUNT(*) SELECT COUNT(*)
@ -66,7 +68,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
WHERE (users.appservice_id IS NULL OR users.appservice_id = ''); WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
""" """
txn.execute(sql) txn.execute(sql)
(count,) = txn.fetchone() (count,) = cast(Tuple[int], txn.fetchone())
return count return count
return await self.db_pool.runInteraction("count_users", _count_users) return await self.db_pool.runInteraction("count_users", _count_users)
@ -84,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
""" """
def _count_users_by_service(txn): def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]:
sql = """ sql = """
SELECT COALESCE(appservice_id, 'native'), COUNT(*) SELECT COALESCE(appservice_id, 'native'), COUNT(*)
FROM monthly_active_users FROM monthly_active_users
@ -93,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
""" """
txn.execute(sql) txn.execute(sql)
result = txn.fetchall() result = cast(List[Tuple[str, int]], txn.fetchall())
return dict(result) return dict(result)
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
@ -141,12 +143,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
) )
@wrap_as_background_process("reap_monthly_active_users") @wrap_as_background_process("reap_monthly_active_users")
async def reap_monthly_active_users(self): async def reap_monthly_active_users(self) -> None:
"""Cleans out monthly active user table to ensure that no stale """Cleans out monthly active user table to ensure that no stale
entries exist. entries exist.
""" """
def _reap_users(txn, reserved_users): def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None:
""" """
Args: Args:
reserved_users (tuple): reserved users to preserve reserved_users (tuple): reserved users to preserve
@ -210,10 +212,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
# is racy. # is racy.
# Have resolved to invalidate the whole cache for now and do # Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant # something about it if and when the perf becomes significant
self._invalidate_all_cache_and_stream( self._invalidate_all_cache_and_stream( # type: ignore[attr-defined]
txn, self.user_last_seen_monthly_active txn, self.user_last_seen_monthly_active
) )
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined]
reserved_users = await self.get_registered_reserved_users() reserved_users = await self.get_registered_reserved_users()
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
@ -221,7 +223,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
) )
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
def __init__( def __init__(
self, self,
database: DatabasePool, database: DatabasePool,
@ -242,13 +244,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value], hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
) )
def _initialise_reserved_users(self, txn, threepids): def _initialise_reserved_users(
self, txn: LoggingTransaction, threepids: List[dict]
) -> None:
"""Ensures that reserved threepids are accounted for in the MAU table, should """Ensures that reserved threepids are accounted for in the MAU table, should
be called on start up. be called on start up.
Args: Args:
txn (cursor): txn:
threepids (list[dict]): List of threepid dicts to reserve threepids: List of threepid dicts to reserve
""" """
# XXX what is this function trying to achieve? It upserts into # XXX what is this function trying to achieve? It upserts into
@ -299,7 +303,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
) )
def upsert_monthly_active_user_txn(self, txn, user_id): def upsert_monthly_active_user_txn(
self, txn: LoggingTransaction, user_id: str
) -> None:
"""Updates or inserts monthly active user member """Updates or inserts monthly active user member
We consciously do not call is_support_txn from this method because it We consciously do not call is_support_txn from this method because it
@ -336,7 +342,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
txn, self.user_last_seen_monthly_active, (user_id,) txn, self.user_last_seen_monthly_active, (user_id,)
) )
async def populate_monthly_active_users(self, user_id): async def populate_monthly_active_users(self, user_id: str) -> None:
"""Checks on the state of monthly active user limits and optionally """Checks on the state of monthly active user limits and optionally
add the user to the monthly active tables add the user to the monthly active tables
@ -345,7 +351,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
""" """
if self._limit_usage_by_mau or self._mau_stats_only: if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group # Trial users and guests should not be included as part of MAU group
is_guest = await self.is_guest(user_id) is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]
if is_guest: if is_guest:
return return
is_trial = await self.is_trial_user(user_id) is_trial = await self.is_trial_user(user_id)

View file

@ -24,10 +24,9 @@ from typing import (
Optional, Optional,
Set, Set,
Tuple, Tuple,
cast,
) )
from twisted.internet import defer
from synapse.api.constants import ReceiptTypes from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream from synapse.replication.tcp.streams import ReceiptsStream
@ -38,7 +37,11 @@ from synapse.storage.database import (
LoggingTransaction, LoggingTransaction,
) )
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.id_generators import (
AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
@ -58,6 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
hs: "HomeServer", hs: "HomeServer",
): ):
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
self._receipts_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine): if isinstance(database.engine, PostgresEngine):
self._can_write_to_receipts = ( self._can_write_to_receipts = (
@ -161,7 +165,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
" AND user_id = ?" " AND user_id = ?"
) )
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
return txn.fetchall() return cast(List[Tuple[str, str, int, int]], txn.fetchall())
rows = await self.db_pool.runInteraction( rows = await self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f "get_receipts_for_user_with_orderings", f
@ -257,7 +261,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
if not rows: if not rows:
return [] return []
content = {} content: JsonDict = {}
for row in rows: for row in rows:
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["user_id"] row["user_id"]
@ -305,7 +309,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"_get_linearized_receipts_for_rooms", f "_get_linearized_receipts_for_rooms", f
) )
results = {} results: JsonDict = {}
for row in txn_results: for row in txn_results:
# We want a single event per room, since we want to batch the # We want a single event per room, since we want to batch the
# receipts by room, event and type. # receipts by room, event and type.
@ -370,7 +374,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"get_linearized_receipts_for_all_rooms", f "get_linearized_receipts_for_all_rooms", f
) )
results = {} results: JsonDict = {}
for row in txn_results: for row in txn_results:
# We want a single event per room, since we want to batch the # We want a single event per room, since we want to batch the
# receipts by room, event and type. # receipts by room, event and type.
@ -399,7 +403,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
""" """
if last_id == current_id: if last_id == current_id:
return defer.succeed([]) return []
def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]: def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
sql = """ sql = """
@ -453,7 +457,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
""" """
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn] updates = cast(
List[Tuple[int, list]],
[(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
)
limited = False limited = False
upper_bound = current_id upper_bound = current_id
@ -496,7 +503,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type)) self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, instance_name, token, rows): def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[Any],
) -> None:
if stream_name == ReceiptsStream.NAME: if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token) self._receipts_id_gen.advance(instance_name, token)
for row in rows: for row in rows:
@ -584,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
) )
if receipt_type == ReceiptTypes.READ and stream_ordering is not None: if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
self._remove_old_push_actions_before_txn( self._remove_old_push_actions_before_txn( # type: ignore[attr-defined]
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
) )
@ -637,7 +650,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"insert_receipt_conv", graph_to_linear "insert_receipt_conv", graph_to_linear
) )
async with self._receipts_id_gen.get_next() as stream_id: async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
event_ts = await self.db_pool.runInteraction( event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt", "insert_linearized_receipt",
self.insert_linearized_receipt_txn, self.insert_linearized_receipt_txn,

View file

@ -22,6 +22,7 @@ import attr
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
@ -123,7 +124,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.config = hs.config self.config: HomeServerConfig = hs.config
# Note: we don't check this sequence for consistency as we'd have to # Note: we don't check this sequence for consistency as we'd have to
# call `find_max_generated_user_id_localpart` each time, which is # call `find_max_generated_user_id_localpart` each time, which is

View file

@ -27,7 +27,6 @@ from typing import (
) )
import attr import attr
from frozendict import frozendict
from synapse.api.constants import RelationTypes from synapse.api.constants import RelationTypes
from synapse.events import EventBase from synapse.events import EventBase
@ -41,45 +40,15 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
from synapse.types import JsonDict, RoomStreamToken, StreamToken from synapse.types import RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
latest_event: EventBase
# The latest edit to the latest event in the thread.
latest_edit: Optional[EventBase]
# The total number of events in the thread.
count: int
# True if the current user has sent an event to the thread.
current_user_participated: bool
@attr.s(slots=True, auto_attribs=True)
class BundledAggregations:
"""
The bundled aggregations for an event.
Some values require additional processing during serialization.
"""
annotations: Optional[JsonDict] = None
references: Optional[JsonDict] = None
replace: Optional[EventBase] = None
thread: Optional[_ThreadAggregation] = None
def __bool__(self) -> bool:
return bool(self.annotations or self.references or self.replace or self.thread)
class RelationsWorkerStore(SQLBaseStore): class RelationsWorkerStore(SQLBaseStore):
def __init__( def __init__(
self, self,
@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError() raise NotImplementedError()
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
async def _get_applicable_edits( async def get_applicable_edits(
self, event_ids: Collection[str] self, event_ids: Collection[str]
) -> Dict[str, Optional[EventBase]]: ) -> Dict[str, Optional[EventBase]]:
"""Get the most recent edit (if any) that has happened for the given """Get the most recent edit (if any) that has happened for the given
@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError() raise NotImplementedError()
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids") @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
async def _get_thread_summaries( async def get_thread_summaries(
self, event_ids: Collection[str] self, event_ids: Collection[str]
) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]: ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
"""Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event. """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
# Check to see if any of those events are edited. # Check to see if any of those events are edited.
latest_edits = await self._get_applicable_edits(latest_event_ids.values()) latest_edits = await self.get_applicable_edits(latest_event_ids.values())
# Map to the event IDs to the thread summary. # Map to the event IDs to the thread summary.
# #
@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError() raise NotImplementedError()
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids") @cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
async def _get_threads_participated( async def get_threads_participated(
self, event_ids: Collection[str], user_id: str self, event_ids: Collection[str], user_id: str
) -> Dict[str, bool]: ) -> Dict[str, bool]:
"""Get whether the requesting user participated in the given threads. """Get whether the requesting user participated in the given threads.
@ -766,114 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
) )
async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
Note that this does not use a cache, but depends on cached methods.
Args:
event: The event to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
The bundled aggregations for an event, if bundled aggregations are
enabled and the event can have bundled aggregations.
"""
# Do not bundle aggregations for an event which represents an edit or an
# annotation. It does not make sense for them to have related events.
relates_to = event.content.get("m.relates_to")
if isinstance(relates_to, (dict, frozendict)):
relation_type = relates_to.get("rel_type")
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
return None
event_id = event.event_id
room_id = event.room_id
# The bundled aggregations to include, a mapping of relation type to a
# type-specific value. Some types include the direct return type here
# while others need more processing during serialization.
aggregations = BundledAggregations()
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
if annotations.chunk:
aggregations.annotations = await annotations.to_dict(
cast("DataStore", self)
)
references = await self.get_relations_for_event(
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self))
# Store the bundled aggregations in the event metadata for later use.
return aggregations
async def get_bundled_aggregations(
self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.
Args:
events: The iterable of events to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results.
"""
# De-duplicate events by ID to handle the same event requested multiple times.
#
# State events do not get bundled aggregations.
events_by_id = {
event.event_id: event for event in events if not event.is_state()
}
# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}
# Fetch other relations per event.
for event in events_by_id.values():
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result:
results[event.event_id] = event_result
# Fetch any edits (but not for redacted events).
edits = await self._get_applicable_edits(
[
event_id
for event_id, event in events_by_id.items()
if not event.internal_metadata.is_redacted()
]
)
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit
# Fetch thread summaries.
summaries = await self._get_thread_summaries(events_by_id.keys())
# Only fetch participated for a limited selection based on what had
# summaries.
participated = await self._get_threads_participated(summaries.keys(), user_id)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event, edit = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)
return results
class RelationsStore(RelationsWorkerStore): class RelationsStore(RelationsWorkerStore):
pass pass

View file

@ -34,6 +34,7 @@ import attr
from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import ( from synapse.storage.database import (
@ -98,7 +99,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.config = hs.config self.config: HomeServerConfig = hs.config
async def store_room( async def store_room(
self, self,

View file

@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
@attr.s(frozen=True, slots=True, auto_attribs=True)
class EventIdMembership:
"""Returned by `get_membership_from_event_ids`"""
user_id: str
membership: str
class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberWorkerStore(EventsWorkerStore):
def __init__( def __init__(
self, self,
@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
retcols=("user_id", "display_name", "avatar_url", "event_id"), retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN}, keyvalues={"membership": Membership.JOIN},
batch_size=500, batch_size=500,
desc="_get_membership_from_event_ids", desc="_get_joined_profiles_from_event_ids",
) )
return { return {
@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids) return set(room_ids)
@cached(max_entries=5000)
async def _get_membership_from_event_id(
self, member_event_id: str
) -> Optional[EventIdMembership]:
raise NotImplementedError()
@cachedList(
cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
)
async def get_membership_from_event_ids( async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str] self, member_event_ids: Iterable[str]
) -> List[dict]: ) -> Dict[str, Optional[EventIdMembership]]:
"""Get user_id and membership of a set of event IDs.""" """Get user_id and membership of a set of event IDs.
return await self.db_pool.simple_select_many_batch( Returns:
Mapping from event ID to `EventIdMembership` if the event is a
membership event, otherwise the value is None.
"""
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships", table="room_memberships",
column="event_id", column="event_id",
iterable=member_event_ids, iterable=member_event_ids,
@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
desc="get_membership_from_event_ids", desc="get_membership_from_event_ids",
) )
return {
row["event_id"]: EventIdMembership(
membership=row["membership"], user_id=row["user_id"]
)
for row in rows
}
async def is_local_host_in_room_ignoring_users( async def is_local_host_in_room_ignoring_users(
self, room_id: str, ignore_users: Collection[str] self, room_id: str, ignore_users: Collection[str]
) -> bool: ) -> bool:

View file

@ -14,7 +14,7 @@
import logging import logging
import re import re
from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
import attr import attr
@ -74,7 +74,7 @@ class SearchWorkerStore(SQLBaseStore):
" VALUES (?,?,?,to_tsvector('english', ?),?,?)" " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
) )
args = ( args1 = (
( (
entry.event_id, entry.event_id,
entry.room_id, entry.room_id,
@ -86,14 +86,14 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries for entry in entries
) )
txn.execute_batch(sql, args) txn.execute_batch(sql, args1)
elif isinstance(self.database_engine, Sqlite3Engine): elif isinstance(self.database_engine, Sqlite3Engine):
sql = ( sql = (
"INSERT INTO event_search (event_id, room_id, key, value)" "INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)" " VALUES (?,?,?,?)"
) )
args = ( args2 = (
( (
entry.event_id, entry.event_id,
entry.room_id, entry.room_id,
@ -102,7 +102,7 @@ class SearchWorkerStore(SQLBaseStore):
) )
for entry in entries for entry in entries
) )
txn.execute_batch(sql, args) txn.execute_batch(sql, args2)
else: else:
# This should be unreachable. # This should be unreachable.
@ -427,7 +427,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = _parse_query(self.database_engine, search_term) search_query = _parse_query(self.database_engine, search_term)
args = [] args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms. # Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless. # We filter the results below regardless.
@ -496,7 +496,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak) # search results (which is a data leak)
events = await self.get_events_as_list( events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results], [r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK, redact_behaviour=EventRedactBehaviour.BLOCK,
) )
@ -530,7 +530,7 @@ class SearchStore(SearchBackgroundUpdateStore):
room_ids: Collection[str], room_ids: Collection[str],
search_term: str, search_term: str,
keys: Iterable[str], keys: Iterable[str],
limit, limit: int,
pagination_token: Optional[str] = None, pagination_token: Optional[str] = None,
) -> JsonDict: ) -> JsonDict:
"""Performs a full text search over events with given keys. """Performs a full text search over events with given keys.
@ -549,7 +549,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = _parse_query(self.database_engine, search_term) search_query = _parse_query(self.database_engine, search_term)
args = [] args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms. # Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless. # We filter the results below regardless.
@ -573,9 +573,9 @@ class SearchStore(SearchBackgroundUpdateStore):
if pagination_token: if pagination_token:
try: try:
origin_server_ts, stream = pagination_token.split(",") origin_server_ts_str, stream_str = pagination_token.split(",")
origin_server_ts = int(origin_server_ts) origin_server_ts = int(origin_server_ts_str)
stream = int(stream) stream = int(stream_str)
except Exception: except Exception:
raise SynapseError(400, "Invalid pagination token") raise SynapseError(400, "Invalid pagination token")
@ -654,7 +654,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak) # search results (which is a data leak)
events = await self.get_events_as_list( events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results], [r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK, redact_behaviour=EventRedactBehaviour.BLOCK,
) )

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import collections.abc import collections.abc
import logging import logging
from typing import TYPE_CHECKING, Iterable, Optional, Set from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@ -29,7 +29,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import StateMap from synapse.types import JsonDict, StateMap
from synapse.util.caches import intern_string from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
@ -241,7 +241,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# We delegate to the cached version # We delegate to the cached version
return await self.get_current_state_ids(room_id) return await self.get_current_state_ids(room_id)
def _get_filtered_current_state_ids_txn(txn): def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
) -> StateMap[str]:
results = {} results = {}
sql = """ sql = """
SELECT type, state_key, event_id FROM current_state_events SELECT type, state_key, event_id FROM current_state_events
@ -281,11 +283,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
event_id = state.get((EventTypes.CanonicalAlias, "")) event_id = state.get((EventTypes.CanonicalAlias, ""))
if not event_id: if not event_id:
return return None
event = await self.get_event(event_id, allow_none=True) event = await self.get_event(event_id, allow_none=True)
if not event: if not event:
return return None
return event.content.get("canonical_alias") return event.content.get("canonical_alias")
@ -304,7 +306,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
list_name="event_ids", list_name="event_ids",
num_args=1, num_args=1,
) )
async def _get_state_group_for_events(self, event_ids): async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
"""Returns mapping event_id -> state_group""" """Returns mapping event_id -> state_group"""
rows = await self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups", table="event_to_state_groups",
@ -355,7 +357,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name = hs.hostname self.server_name: str = hs.hostname
self.db_pool.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME, self.CURRENT_STATE_INDEX_UPDATE_NAME,
@ -375,7 +377,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
self._background_remove_left_rooms, self._background_remove_left_rooms,
) )
async def _background_remove_left_rooms(self, progress, batch_size): async def _background_remove_left_rooms(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to delete rows from `current_state_events` and """Background update to delete rows from `current_state_events` and
`event_forward_extremities` tables of rooms that the server is no `event_forward_extremities` tables of rooms that the server is no
longer joined to. longer joined to.
@ -383,7 +387,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
last_room_id = progress.get("last_room_id", "") last_room_id = progress.get("last_room_id", "")
def _background_remove_left_rooms_txn(txn): def _background_remove_left_rooms_txn(
txn: LoggingTransaction,
) -> Tuple[bool, Set[str]]:
# get a batch of room ids to consider # get a batch of room ids to consider
sql = """ sql = """
SELECT DISTINCT room_id FROM current_state_events SELECT DISTINCT room_id FROM current_state_events

View file

@ -108,7 +108,7 @@ class StatsStore(StateDeltasStore):
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name = hs.hostname self.server_name: str = hs.hostname
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.stats_enabled = hs.config.stats.stats_enabled self.stats_enabled = hs.config.stats.stats_enabled

View file

@ -26,6 +26,8 @@ from typing import (
cast, cast,
) )
from typing_extensions import TypedDict
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
if TYPE_CHECKING: if TYPE_CHECKING:
@ -40,7 +42,12 @@ from synapse.storage.database import (
from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id from synapse.types import (
JsonDict,
UserProfile,
get_domain_from_id,
get_localpart_from_id,
)
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -61,7 +68,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
) -> None: ) -> None:
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name = hs.hostname self.server_name: str = hs.hostname
self.db_pool.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
"populate_user_directory_createtables", "populate_user_directory_createtables",
@ -591,6 +598,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
) )
class SearchResult(TypedDict):
limited: bool
results: List[UserProfile]
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# How many records do we calculate before sending it to # How many records do we calculate before sending it to
# add_users_who_share_private_rooms? # add_users_who_share_private_rooms?
@ -718,7 +730,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows) users.update(rows)
return list(users) return list(users)
async def get_shared_rooms_for_users( async def get_mutual_rooms_for_users(
self, user_id: str, other_user_id: str self, user_id: str, other_user_id: str
) -> Set[str]: ) -> Set[str]:
""" """
@ -732,7 +744,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
A set of room ID's that the users share. A set of room ID's that the users share.
""" """
def _get_shared_rooms_for_users_txn( def _get_mutual_rooms_for_users_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
txn.execute( txn.execute(
@ -756,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return rows return rows
rows = await self.db_pool.runInteraction( rows = await self.db_pool.runInteraction(
"get_shared_rooms_for_users", _get_shared_rooms_for_users_txn "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn
) )
return {row["room_id"] for row in rows} return {row["room_id"] for row in rows}
@ -777,7 +789,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
async def search_user_dir( async def search_user_dir(
self, user_id: str, search_term: str, limit: int self, user_id: str, search_term: str, limit: int
) -> JsonDict: ) -> SearchResult:
"""Searches for users in directory """Searches for users in directory
Returns: Returns:
@ -910,8 +922,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# This should be unreachable. # This should be unreachable.
raise Exception("Unrecognized database engine") raise Exception("Unrecognized database engine")
results = await self.db_pool.execute( results = cast(
List[UserProfile],
await self.db_pool.execute(
"search_user_dir", self.db_pool.cursor_to_dict, sql, *args "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
),
) )
limited = len(results) > limit limited = len(results) > limit

View file

@ -27,7 +27,7 @@ def create_engine(database_config) -> BaseDatabaseEngine:
if name == "psycopg2": if name == "psycopg2":
# Note that psycopg2cffi-compat provides the psycopg2 module on pypy. # Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
import psycopg2 # type: ignore import psycopg2
return PostgresEngine(psycopg2, database_config) return PostgresEngine(psycopg2, database_config)

View file

@ -47,17 +47,26 @@ class PostgresEngine(BaseDatabaseEngine):
self.default_isolation_level = ( self.default_isolation_level = (
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
) )
self.config = database_config
@property @property
def single_threaded(self) -> bool: def single_threaded(self) -> bool:
return False return False
def get_db_locale(self, txn):
txn.execute(
"SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
)
collation, ctype = txn.fetchone()
return collation, ctype
def check_database(self, db_conn, allow_outdated_version: bool = False): def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2 # Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and # docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them # revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105 # together. For example, version 8.1.5 will be returned as 80105
self._version = db_conn.server_version self._version = db_conn.server_version
allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
# Are we on a supported PostgreSQL version? # Are we on a supported PostgreSQL version?
if not allow_outdated_version and self._version < 100000: if not allow_outdated_version and self._version < 100000:
@ -72,21 +81,30 @@ class PostgresEngine(BaseDatabaseEngine):
"See docs/postgres.md for more information." % (rows[0][0],) "See docs/postgres.md for more information." % (rows[0][0],)
) )
txn.execute( collation, ctype = self.get_db_locale(txn)
"SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
)
collation, ctype = txn.fetchone()
if collation != "C": if collation != "C":
logger.warning( logger.warning(
"Database has incorrect collation of %r. Should be 'C'",
collation,
)
if not allow_unsafe_locale:
raise IncorrectDatabaseSetup(
"Database has incorrect collation of %r. Should be 'C'\n" "Database has incorrect collation of %r. Should be 'C'\n"
"See docs/postgres.md for more information.", "See docs/postgres.md for more information. You can override this check by"
"setting 'allow_unsafe_locale' to true in the database config.",
collation, collation,
) )
if ctype != "C": if ctype != "C":
if not allow_unsafe_locale:
logger.warning( logger.warning(
"Database has incorrect ctype of %r. Should be 'C'",
ctype,
)
raise IncorrectDatabaseSetup(
"Database has incorrect ctype of %r. Should be 'C'\n" "Database has incorrect ctype of %r. Should be 'C'\n"
"See docs/postgres.md for more information.", "See docs/postgres.md for more information. You can override this check by"
"setting 'allow_unsafe_locale' to true in the database config.",
ctype, ctype,
) )
@ -95,10 +113,7 @@ class PostgresEngine(BaseDatabaseEngine):
apply stricter checks on new databases versus existing database. apply stricter checks on new databases versus existing database.
""" """
txn.execute( collation, ctype = self.get_db_locale(txn)
"SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
)
collation, ctype = txn.fetchone()
errors = [] errors = []

View file

@ -1023,8 +1023,13 @@ class EventsPersistenceStorage:
# Check if any of the changes that we don't have events for are joins. # Check if any of the changes that we don't have events for are joins.
if events_to_check: if events_to_check:
rows = await self.main_store.get_membership_from_event_ids(events_to_check) members = await self.main_store.get_membership_from_event_ids(
is_still_joined = any(row["membership"] == Membership.JOIN for row in rows) events_to_check
)
is_still_joined = any(
member and member.membership == Membership.JOIN
for member in members.values()
)
if is_still_joined: if is_still_joined:
return True return True
@ -1060,9 +1065,11 @@ class EventsPersistenceStorage:
), event_id in current_state.items() ), event_id in current_state.items()
if typ == EventTypes.Member and not self.is_mine_id(state_key) if typ == EventTypes.Member and not self.is_mine_id(state_key)
] ]
rows = await self.main_store.get_membership_from_event_ids(remote_event_ids) members = await self.main_store.get_membership_from_event_ids(remote_event_ids)
potentially_left_users.update( potentially_left_users.update(
row["user_id"] for row in rows if row["membership"] == Membership.JOIN member.user_id
for member in members.values()
if member and member.membership == Membership.JOIN
) )
return False return False

View file

@ -34,6 +34,7 @@ from typing import (
import attr import attr
from frozendict import frozendict from frozendict import frozendict
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from typing_extensions import TypedDict
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
from zope.interface import Interface from zope.interface import Interface
@ -63,6 +64,10 @@ MutableStateMap = MutableMapping[StateKey, T]
# JSON types. These could be made stronger, but will do for now. # JSON types. These could be made stronger, but will do for now.
# A JSON-serialisable dict. # A JSON-serialisable dict.
JsonDict = Dict[str, Any] JsonDict = Dict[str, Any]
# A JSON-serialisable mapping; roughly speaking an immutable JSONDict.
# Useful when you have a TypedDict which isn't going to be mutated and you don't want
# to cast to JsonDict everywhere.
JsonMapping = Mapping[str, Any]
# A JSON-serialisable object. # A JSON-serialisable object.
JsonSerializable = object JsonSerializable = object
@ -791,3 +796,9 @@ class UserInfo:
is_deactivated: bool is_deactivated: bool
is_guest: bool is_guest: bool
is_shadow_banned: bool is_shadow_banned: bool
class UserProfile(TypedDict):
user_id: str
display_name: Optional[str]
avatar_url: Optional[str]

View file

@ -128,6 +128,19 @@ def _incorrect_version(
) )
def _no_reported_version(requirement: Requirement, extra: Optional[str] = None) -> str:
if extra:
return (
f"Synapse {VERSION} needs {requirement} for {extra}, "
f"but can't determine {requirement.name}'s version"
)
else:
return (
f"Synapse {VERSION} needs {requirement}, "
f"but can't determine {requirement.name}'s version"
)
def check_requirements(extra: Optional[str] = None) -> None: def check_requirements(extra: Optional[str] = None) -> None:
"""Check Synapse's dependencies are present and correctly versioned. """Check Synapse's dependencies are present and correctly versioned.
@ -163,8 +176,17 @@ def check_requirements(extra: Optional[str] = None) -> None:
deps_unfulfilled.append(requirement.name) deps_unfulfilled.append(requirement.name)
errors.append(_not_installed(requirement, extra)) errors.append(_not_installed(requirement, extra))
else: else:
if dist.version is None:
# This shouldn't happen---it suggests a borked virtualenv. (See #12223)
# Try to give a vaguely helpful error message anyway.
# Type-ignore: the annotations don't reflect reality: see
# https://github.com/python/typeshed/issues/7513
# https://bugs.python.org/issue47060
deps_unfulfilled.append(requirement.name) # type: ignore[unreachable]
errors.append(_no_reported_version(requirement, extra))
# We specify prereleases=True to allow prereleases such as RCs. # We specify prereleases=True to allow prereleases such as RCs.
if not requirement.specifier.contains(dist.version, prereleases=True): elif not requirement.specifier.contains(dist.version, prereleases=True):
deps_unfulfilled.append(requirement.name) deps_unfulfilled.append(requirement.name)
errors.append(_incorrect_version(requirement, dist.version, extra)) errors.append(_incorrect_version(requirement, dist.version, extra))

View file

@ -14,12 +14,7 @@
import logging import logging
from typing import Dict, FrozenSet, List, Optional from typing import Dict, FrozenSet, List, Optional
from synapse.api.constants import ( from synapse.api.constants import EventTypes, HistoryVisibility, Membership
AccountDataTypes,
EventTypes,
HistoryVisibility,
Membership,
)
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.storage import Storage from synapse.storage import Storage
@ -87,15 +82,8 @@ async def filter_events_for_client(
state_filter=StateFilter.from_types(types), state_filter=StateFilter.from_types(types),
) )
ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user( # Get the users who are ignored by the requesting user.
user_id, AccountDataTypes.IGNORED_USER_LIST ignore_list = await storage.main.ignored_users(user_id)
)
ignore_list: FrozenSet[str] = frozenset()
if ignore_dict_content:
ignored_users_dict = ignore_dict_content.get("ignored_users", {})
if isinstance(ignored_users_dict, dict):
ignore_list = frozenset(ignored_users_dict.keys())
erased_senders = await storage.main.are_users_erased(e.sender for e in events) erased_senders = await storage.main.are_users_erased(e.sender for e in events)

View file

@ -11,14 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import synapse.app.homeserver
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from tests.unittest import TestCase from tests.config.utils import ConfigFileTestCase
from tests.utils import default_config from tests.utils import default_config
class RegistrationConfigTestCase(TestCase): class RegistrationConfigTestCase(ConfigFileTestCase):
def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self): def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self):
""" """
session_lifetime should logically be larger than, or at least as large as, session_lifetime should logically be larger than, or at least as large as,
@ -76,3 +78,19 @@ class RegistrationConfigTestCase(TestCase):
HomeServerConfig().parse_config_dict( HomeServerConfig().parse_config_dict(
{"session_lifetime": "31m", "refresh_token_lifetime": "31m", **config_dict} {"session_lifetime": "31m", "refresh_token_lifetime": "31m", **config_dict}
) )
def test_refuse_to_start_if_open_registration_and_no_verification(self):
self.generate_config()
self.add_lines_to_config(
[
" ",
"enable_registration: true",
"registrations_require_3pid: []",
"enable_registration_captcha: false",
"registration_requires_token: false",
]
)
# Test that allowing open registration without verification raises an error
with self.assertRaises(ConfigError):
synapse.app.homeserver.setup(["-c", self.config_file])

View file

@ -11,9 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict
from unittest.mock import Mock from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.cas import CasResponse from synapse.handlers.cas import CasResponse
from synapse.server import HomeServer
from synapse.util import Clock
from tests.test_utils import simple_async_mock from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
@ -24,7 +29,7 @@ SERVER_URL = "https://issuer/"
class CasHandlerTestCase(HomeserverTestCase): class CasHandlerTestCase(HomeserverTestCase):
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
config["public_baseurl"] = BASE_URL config["public_baseurl"] = BASE_URL
cas_config = { cas_config = {
@ -40,7 +45,7 @@ class CasHandlerTestCase(HomeserverTestCase):
return config return config
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver() hs = self.setup_test_homeserver()
self.handler = hs.get_cas_handler() self.handler = hs.get_cas_handler()
@ -51,7 +56,7 @@ class CasHandlerTestCase(HomeserverTestCase):
return hs return hs
def test_map_cas_user_to_user(self): def test_map_cas_user_to_user(self) -> None:
"""Ensure that mapping the CAS user returned from a provider to an MXID works properly.""" """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
# stub out the auth handler # stub out the auth handler
@ -75,7 +80,7 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None, auth_provider_session_id=None,
) )
def test_map_cas_user_to_existing_user(self): def test_map_cas_user_to_existing_user(self) -> None:
"""Existing users can log in with CAS account.""" """Existing users can log in with CAS account."""
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
self.get_success( self.get_success(
@ -119,7 +124,7 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None, auth_provider_session_id=None,
) )
def test_map_cas_user_to_invalid_localpart(self): def test_map_cas_user_to_invalid_localpart(self) -> None:
"""CAS automaps invalid characters to base-64 encoding.""" """CAS automaps invalid characters to base-64 encoding."""
# stub out the auth handler # stub out the auth handler
@ -150,7 +155,7 @@ class CasHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_required_attributes(self): def test_required_attributes(self) -> None:
"""The required attributes must be met from the CAS response.""" """The required attributes must be met from the CAS response."""
# stub out the auth handler # stub out the auth handler
@ -166,7 +171,7 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_handler.complete_sso_login.assert_not_called() auth_handler.complete_sso_login.assert_not_called()
# The response doesn't have any department. # The response doesn't have any department.
cas_response = CasResponse("test_user", {"userGroup": "staff"}) cas_response = CasResponse("test_user", {"userGroup": ["staff"]})
request.reset_mock() request.reset_mock()
self.get_success( self.get_success(
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")

View file

@ -12,14 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
import synapse.api.errors import synapse.api.errors
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.rest.client import directory, login, room from synapse.rest.client import directory, login, room
from synapse.types import RoomAlias, create_requester from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, create_requester
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
@ -28,13 +32,15 @@ from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase): class DirectoryTestCase(unittest.HomeserverTestCase):
"""Tests the directory service.""" """Tests the directory service."""
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock() self.mock_federation = Mock()
self.mock_registry = Mock() self.mock_registry = Mock()
self.query_handlers = {} self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
def register_query_handler(query_type, handler): def register_query_handler(
query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
) -> None:
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
self.mock_registry.register_query_handler = register_query_handler self.mock_registry.register_query_handler = register_query_handler
@ -54,7 +60,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
return hs return hs
def test_get_local_association(self): def test_get_local_association(self) -> None:
self.get_success( self.get_success(
self.store.create_room_alias_association( self.store.create_room_alias_association(
self.my_room, "!8765qwer:test", ["test"] self.my_room, "!8765qwer:test", ["test"]
@ -65,7 +71,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result) self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
def test_get_remote_association(self): def test_get_remote_association(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable( self.mock_federation.make_query.return_value = make_awaitable(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]} {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
) )
@ -83,7 +89,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
ignore_backoff=True, ignore_backoff=True,
) )
def test_incoming_fed_query(self): def test_incoming_fed_query(self) -> None:
self.get_success( self.get_success(
self.store.create_room_alias_association( self.store.create_room_alias_association(
self.your_room, "!8765asdf:test", ["test"] self.your_room, "!8765asdf:test", ["test"]
@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
directory.register_servlets, directory.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_directory_handler() self.handler = hs.get_directory_handler()
# Create user # Create user
@ -125,7 +131,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass") self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
def test_create_alias_joined_room(self): def test_create_alias_joined_room(self) -> None:
"""A user can create an alias for a room they're in.""" """A user can create an alias for a room they're in."""
self.get_success( self.get_success(
self.handler.create_association( self.handler.create_association(
@ -135,7 +141,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
) )
) )
def test_create_alias_other_room(self): def test_create_alias_other_room(self) -> None:
"""A user cannot create an alias for a room they're NOT in.""" """A user cannot create an alias for a room they're NOT in."""
other_room_id = self.helper.create_room_as( other_room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok self.admin_user, tok=self.admin_user_tok
@ -150,7 +156,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError, synapse.api.errors.SynapseError,
) )
def test_create_alias_admin(self): def test_create_alias_admin(self) -> None:
"""An admin can create an alias for a room they're NOT in.""" """An admin can create an alias for a room they're NOT in."""
other_room_id = self.helper.create_room_as( other_room_id = self.helper.create_room_as(
self.test_user, tok=self.test_user_tok self.test_user, tok=self.test_user_tok
@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
directory.register_servlets, directory.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler() self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
@ -195,7 +201,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass") self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
def _create_alias(self, user): def _create_alias(self, user) -> None:
# Create a new alias to this room. # Create a new alias to this room.
self.get_success( self.get_success(
self.store.create_room_alias_association( self.store.create_room_alias_association(
@ -203,7 +209,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
) )
) )
def test_delete_alias_not_allowed(self): def test_delete_alias_not_allowed(self) -> None:
"""A user that doesn't meet the expected guidelines cannot delete an alias.""" """A user that doesn't meet the expected guidelines cannot delete an alias."""
self._create_alias(self.admin_user) self._create_alias(self.admin_user)
self.get_failure( self.get_failure(
@ -213,7 +219,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.AuthError, synapse.api.errors.AuthError,
) )
def test_delete_alias_creator(self): def test_delete_alias_creator(self) -> None:
"""An alias creator can delete their own alias.""" """An alias creator can delete their own alias."""
# Create an alias from a different user. # Create an alias from a different user.
self._create_alias(self.test_user) self._create_alias(self.test_user)
@ -232,7 +238,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError, synapse.api.errors.SynapseError,
) )
def test_delete_alias_admin(self): def test_delete_alias_admin(self) -> None:
"""A server admin can delete an alias created by another user.""" """A server admin can delete an alias created by another user."""
# Create an alias from a different user. # Create an alias from a different user.
self._create_alias(self.test_user) self._create_alias(self.test_user)
@ -251,7 +257,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError, synapse.api.errors.SynapseError,
) )
def test_delete_alias_sufficient_power(self): def test_delete_alias_sufficient_power(self) -> None:
"""A user with a sufficient power level should be able to delete an alias.""" """A user with a sufficient power level should be able to delete an alias."""
self._create_alias(self.admin_user) self._create_alias(self.admin_user)
@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
directory.register_servlets, directory.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler() self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
@ -317,7 +323,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
) )
return room_alias return room_alias
def _set_canonical_alias(self, content): def _set_canonical_alias(self, content) -> None:
"""Configure the canonical alias state on the room.""" """Configure the canonical alias state on the room."""
self.helper.send_state( self.helper.send_state(
self.room_id, self.room_id,
@ -334,7 +340,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
) )
) )
def test_remove_alias(self): def test_remove_alias(self) -> None:
"""Removing an alias that is the canonical alias should remove it there too.""" """Removing an alias that is the canonical alias should remove it there too."""
# Set this new alias as the canonical alias for this room # Set this new alias as the canonical alias for this room
self._set_canonical_alias( self._set_canonical_alias(
@ -356,7 +362,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
self.assertNotIn("alias", data["content"]) self.assertNotIn("alias", data["content"])
self.assertNotIn("alt_aliases", data["content"]) self.assertNotIn("alt_aliases", data["content"])
def test_remove_other_alias(self): def test_remove_other_alias(self) -> None:
"""Removing an alias listed as in alt_aliases should remove it there too.""" """Removing an alias listed as in alt_aliases should remove it there too."""
# Create a second alias. # Create a second alias.
other_test_alias = "#test2:test" other_test_alias = "#test2:test"
@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets] servlets = [directory.register_servlets, room.register_servlets]
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
# Add custom alias creation rules to the config. # Add custom alias creation rules to the config.
@ -403,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
return config return config
def test_denied(self): def test_denied(self) -> None:
room_id = self.helper.create_room_as(self.user_id) room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request( channel = self.make_request(
@ -413,7 +419,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
) )
self.assertEqual(403, channel.code, channel.result) self.assertEqual(403, channel.code, channel.result)
def test_allowed(self): def test_allowed(self) -> None:
room_id = self.helper.create_room_as(self.user_id) room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request( channel = self.make_request(
@ -423,7 +429,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
) )
self.assertEqual(200, channel.code, channel.result) self.assertEqual(200, channel.code, channel.result)
def test_denied_during_creation(self): def test_denied_during_creation(self) -> None:
"""A room alias that is not allowed should be rejected during creation.""" """A room alias that is not allowed should be rejected during creation."""
# Invalid room alias. # Invalid room alias.
self.helper.create_room_as( self.helper.create_room_as(
@ -432,7 +438,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
extra_content={"room_alias_name": "foo"}, extra_content={"room_alias_name": "foo"},
) )
def test_allowed_during_creation(self): def test_allowed_during_creation(self) -> None:
"""A valid room alias should be allowed during creation.""" """A valid room alias should be allowed during creation."""
room_id = self.helper.create_room_as( room_id = self.helper.create_room_as(
self.user_id, self.user_id,
@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
data = {"room_alias_name": "unofficial_test"} data = {"room_alias_name": "unofficial_test"}
allowed_localpart = "allowed" allowed_localpart = "allowed"
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
# Add custom room list publication rules to the config. # Add custom room list publication rules to the config.
@ -474,7 +480,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return config return config
def prepare(self, reactor, clock, hs): def prepare(
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
) -> HomeServer:
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass") self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
self.allowed_access_token = self.login(self.allowed_localpart, "pass") self.allowed_access_token = self.login(self.allowed_localpart, "pass")
@ -483,7 +491,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return hs return hs
def test_denied_without_publication_permission(self): def test_denied_without_publication_permission(self) -> None:
""" """
Try to create a room, register an alias for it, and publish it, Try to create a room, register an alias for it, and publish it,
as a user without permission to publish rooms. as a user without permission to publish rooms.
@ -497,7 +505,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=403, expect_code=403,
) )
def test_allowed_when_creating_private_room(self): def test_allowed_when_creating_private_room(self) -> None:
""" """
Try to create a room, register an alias for it, and NOT publish it, Try to create a room, register an alias for it, and NOT publish it,
as a user without permission to publish rooms. as a user without permission to publish rooms.
@ -511,7 +519,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=200, expect_code=200,
) )
def test_allowed_with_publication_permission(self): def test_allowed_with_publication_permission(self) -> None:
""" """
Try to create a room, register an alias for it, and publish it, Try to create a room, register an alias for it, and publish it,
as a user WITH permission to publish rooms. as a user WITH permission to publish rooms.
@ -525,7 +533,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=200, expect_code=200,
) )
def test_denied_publication_with_invalid_alias(self): def test_denied_publication_with_invalid_alias(self) -> None:
""" """
Try to create a room, register an alias for it, and publish it, Try to create a room, register an alias for it, and publish it,
as a user WITH permission to publish rooms. as a user WITH permission to publish rooms.
@ -538,7 +546,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=403, expect_code=403,
) )
def test_can_create_as_private_room_after_rejection(self): def test_can_create_as_private_room_after_rejection(self) -> None:
""" """
After failing to publish a room with an alias as a user without publish permission, After failing to publish a room with an alias as a user without publish permission,
retry as the same user, but without publishing the room. retry as the same user, but without publishing the room.
@ -549,7 +557,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
self.test_denied_without_publication_permission() self.test_denied_without_publication_permission()
self.test_allowed_when_creating_private_room() self.test_allowed_when_creating_private_room()
def test_can_create_with_permission_after_rejection(self): def test_can_create_with_permission_after_rejection(self) -> None:
""" """
After failing to publish a room with an alias as a user without publish permission, After failing to publish a room with an alias as a user without publish permission,
retry as someone with permission, using the same alias. retry as someone with permission, using the same alias.
@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets] servlets = [directory.register_servlets, room.register_servlets]
def prepare(self, reactor, clock, hs): def prepare(
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
) -> HomeServer:
room_id = self.helper.create_room_as(self.user_id) room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request( channel = self.make_request(
@ -579,7 +589,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
return hs return hs
def test_disabling_room_list(self): def test_disabling_room_list(self) -> None:
self.room_list_handler.enable_room_list_search = True self.room_list_handler.enable_room_list_search = True
self.directory_handler.enable_room_list_search = True self.directory_handler.enable_room_list_search = True

View file

@ -20,33 +20,37 @@ from parameterized import parameterized
from signedjson import key as key, sign as sign from signedjson import key as key, sign as sign
from twisted.internet import defer from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=mock.Mock()) return self.setup_test_homeserver(federation_client=mock.Mock())
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_keys_handler() self.handler = hs.get_e2e_keys_handler()
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
def test_query_local_devices_no_devices(self): def test_query_local_devices_no_devices(self) -> None:
"""If the user has no devices, we expect an empty list.""" """If the user has no devices, we expect an empty list."""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
res = self.get_success(self.handler.query_local_devices({local_user: None})) res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}}) self.assertDictEqual(res, {local_user: {}})
def test_reupload_one_time_keys(self): def test_reupload_one_time_keys(self) -> None:
"""we should be able to re-upload the same keys""" """we should be able to re-upload the same keys"""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
device_id = "xyz" device_id = "xyz"
keys = { keys: JsonDict = {
"alg1:k1": "key1", "alg1:k1": "key1",
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}}, "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"alg2:k3": {"key": "key3"}, "alg2:k3": {"key": "key3"},
@ -74,7 +78,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}} res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
) )
def test_change_one_time_keys(self): def test_change_one_time_keys(self) -> None:
"""attempts to change one-time-keys should be rejected""" """attempts to change one-time-keys should be rejected"""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
@ -134,7 +138,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
SynapseError, SynapseError,
) )
def test_claim_one_time_key(self): def test_claim_one_time_key(self) -> None:
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
device_id = "xyz" device_id = "xyz"
keys = {"alg1:k1": "key1"} keys = {"alg1:k1": "key1"}
@ -161,7 +165,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}, },
) )
def test_fallback_key(self): def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
device_id = "xyz" device_id = "xyz"
fallback_key = {"alg1:k1": "fallback_key1"} fallback_key = {"alg1:k1": "fallback_key1"}
@ -294,7 +298,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
) )
def test_replace_master_key(self): def test_replace_master_key(self) -> None:
"""uploading a new signing key should make the old signing key unavailable""" """uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
keys1 = { keys1 = {
@ -328,7 +332,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
) )
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
def test_reupload_signatures(self): def test_reupload_signatures(self) -> None:
"""re-uploading a signature should not fail""" """re-uploading a signature should not fail"""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
keys1 = { keys1 = {
@ -433,7 +437,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1) self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2) self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
def test_self_signing_key_doesnt_show_up_as_device(self): def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
"""signing keys should be hidden when fetching a user's devices""" """signing keys should be hidden when fetching a user's devices"""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
keys1 = { keys1 = {
@ -462,7 +466,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res = self.get_success(self.handler.query_local_devices({local_user: None})) res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}}) self.assertDictEqual(res, {local_user: {}})
def test_upload_signatures(self): def test_upload_signatures(self) -> None:
"""should check signatures that are uploaded""" """should check signatures that are uploaded"""
# set up a user with cross-signing keys and a device. This user will # set up a user with cross-signing keys and a device. This user will
# try uploading signatures # try uploading signatures
@ -686,7 +690,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
) )
def test_query_devices_remote_no_sync(self): def test_query_devices_remote_no_sync(self) -> None:
"""Tests that querying keys for a remote user that we don't share a room """Tests that querying keys for a remote user that we don't share a room
with returns the cross signing keys correctly. with returns the cross signing keys correctly.
""" """
@ -759,7 +763,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}, },
) )
def test_query_devices_remote_sync(self): def test_query_devices_remote_sync(self) -> None:
"""Tests that querying keys for a remote user that we share a room with, """Tests that querying keys for a remote user that we share a room with,
but haven't yet fetched the keys for, returns the cross signing keys but haven't yet fetched the keys for, returns the cross signing keys
correctly. correctly.
@ -845,7 +849,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
(["device_1", "device_2"],), (["device_1", "device_2"],),
] ]
) )
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]): def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None:
"""Test that requests for all of a remote user's devices are cached. """Test that requests for all of a remote user's devices are cached.
We do this by asserting that only one call over federation was made, and that We do this by asserting that only one call over federation was made, and that
@ -853,7 +857,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
""" """
local_user_id = "@test:test" local_user_id = "@test:test"
remote_user_id = "@test:other" remote_user_id = "@test:other"
request_body = {"device_keys": {remote_user_id: []}} request_body: JsonDict = {"device_keys": {remote_user_id: []}}
response_devices = [ response_devices = [
{ {

View file

@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List from typing import List, cast
from unittest import TestCase from unittest import TestCase
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
@ -23,7 +25,9 @@ from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext, run_in_background from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import create_requester from synapse.types import create_requester
from synapse.util import Clock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from tests import unittest from tests import unittest
@ -42,7 +46,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(federation_http_client=None) hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler() self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
@ -50,7 +54,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self._event_auth_handler = hs.get_event_auth_handler() self._event_auth_handler = hs.get_event_auth_handler()
return hs return hs
def test_exchange_revoked_invite(self): def test_exchange_revoked_invite(self) -> None:
user_id = self.register_user("kermit", "test") user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test") tok = self.login("kermit", "test")
@ -96,7 +100,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure) self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
self.assertEqual(failure.msg, "You are not invited to this room.") self.assertEqual(failure.msg, "You are not invited to this room.")
def test_rejected_message_event_state(self): def test_rejected_message_event_state(self) -> None:
""" """
Check that we store the state group correctly for rejected non-state events. Check that we store the state group correctly for rejected non-state events.
@ -126,7 +130,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
"content": {}, "content": {},
"room_id": room_id, "room_id": room_id,
"sender": "@yetanotheruser:" + OTHER_SERVER, "sender": "@yetanotheruser:" + OTHER_SERVER,
"depth": join_event["depth"] + 1, "depth": cast(int, join_event["depth"]) + 1,
"prev_events": [join_event.event_id], "prev_events": [join_event.event_id],
"auth_events": [], "auth_events": [],
"origin_server_ts": self.clock.time_msec(), "origin_server_ts": self.clock.time_msec(),
@ -149,7 +153,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2) self.assertEqual(sg, sg2)
def test_rejected_state_event_state(self): def test_rejected_state_event_state(self) -> None:
""" """
Check that we store the state group correctly for rejected state events. Check that we store the state group correctly for rejected state events.
@ -180,7 +184,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
"content": {}, "content": {},
"room_id": room_id, "room_id": room_id,
"sender": "@yetanotheruser:" + OTHER_SERVER, "sender": "@yetanotheruser:" + OTHER_SERVER,
"depth": join_event["depth"] + 1, "depth": cast(int, join_event["depth"]) + 1,
"prev_events": [join_event.event_id], "prev_events": [join_event.event_id],
"auth_events": [], "auth_events": [],
"origin_server_ts": self.clock.time_msec(), "origin_server_ts": self.clock.time_msec(),
@ -203,7 +207,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2) self.assertEqual(sg, sg2)
def test_backfill_with_many_backward_extremities(self): def test_backfill_with_many_backward_extremities(self) -> None:
""" """
Check that we can backfill with many backward extremities. Check that we can backfill with many backward extremities.
The goal is to make sure that when we only use a portion The goal is to make sure that when we only use a portion
@ -262,7 +266,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
) )
self.get_success(d) self.get_success(d)
def test_backfill_floating_outlier_membership_auth(self): def test_backfill_floating_outlier_membership_auth(self) -> None:
""" """
As the local homeserver, check that we can properly process a federated As the local homeserver, check that we can properly process a federated
event from the OTHER_SERVER with auth_events that include a floating event from the OTHER_SERVER with auth_events that include a floating
@ -377,7 +381,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
for ae in auth_events for ae in auth_events
] ]
self.handler.federation_client.get_event_auth = get_event_auth self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment]
with LoggingContext("receive_pdu"): with LoggingContext("receive_pdu"):
# Fake the OTHER_SERVER federating the message event over to our local homeserver # Fake the OTHER_SERVER federating the message event over to our local homeserver
@ -397,7 +401,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
@unittest.override_config( @unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
) )
def test_invite_by_user_ratelimit(self): def test_invite_by_user_ratelimit(self) -> None:
"""Tests that invites from federation to a particular user are """Tests that invites from federation to a particular user are
actually rate-limited. actually rate-limited.
""" """
@ -446,7 +450,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
exc=LimitExceededError, exc=LimitExceededError,
) )
def _build_and_send_join_event(self, other_server, other_user, room_id): def _build_and_send_join_event(
self, other_server: str, other_user: str, room_id: str
) -> EventBase:
join_event = self.get_success( join_event = self.get_success(
self.handler.on_make_join_request(other_server, room_id, other_user) self.handler.on_make_join_request(other_server, room_id, other_user)
) )
@ -469,7 +475,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
class EventFromPduTestCase(TestCase): class EventFromPduTestCase(TestCase):
def test_valid_json(self): def test_valid_json(self) -> None:
"""Valid JSON should be turned into an event.""" """Valid JSON should be turned into an event."""
ev = event_from_pdu_json( ev = event_from_pdu_json(
{ {
@ -487,7 +493,7 @@ class EventFromPduTestCase(TestCase):
self.assertIsInstance(ev, EventBase) self.assertIsInstance(ev, EventBase)
def test_invalid_numbers(self): def test_invalid_numbers(self) -> None:
"""Invalid values for an integer should be rejected, all floats should be rejected.""" """Invalid values for an integer should be rejected, all floats should be rejected."""
for value in [ for value in [
-(2 ** 53), -(2 ** 53),
@ -512,7 +518,7 @@ class EventFromPduTestCase(TestCase):
RoomVersions.V6, RoomVersions.V6,
) )
def test_invalid_nested(self): def test_invalid_nested(self) -> None:
"""List and dictionaries are recursively searched.""" """List and dictionaries are recursively searched."""
with self.assertRaises(SynapseError): with self.assertRaises(SynapseError):
event_from_pdu_json( event_from_pdu_json(

View file

@ -13,14 +13,18 @@
# limitations under the License. # limitations under the License.
import json import json
import os import os
from typing import Any, Dict
from unittest.mock import ANY, Mock, patch from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
import pymacaroons import pymacaroons
from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.sso import MappingException from synapse.handlers.sso import MappingException
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import UserID from synapse.types import JsonDict, UserID
from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon from synapse.util.macaroons import get_value_from_macaroon
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider):
} }
async def get_json(url): async def get_json(url: str) -> JsonDict:
# Mock get_json calls to handle jwks & oidc discovery endpoints # Mock get_json calls to handle jwks & oidc discovery endpoints
if url == WELL_KNOWN: if url == WELL_KNOWN:
# Minimal discovery document, as defined in OpenID.Discovery # Minimal discovery document, as defined in OpenID.Discovery
@ -116,6 +120,8 @@ async def get_json(url):
elif url == JWKS_URI: elif url == JWKS_URI:
return {"keys": []} return {"keys": []}
return {}
def _key_file_path() -> str: def _key_file_path() -> str:
"""path to a file containing the private half of a test key""" """path to a file containing the private half of a test key"""
@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
if not HAS_OIDC: if not HAS_OIDC:
skip = "requires OIDC" skip = "requires OIDC"
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
config["public_baseurl"] = BASE_URL config["public_baseurl"] = BASE_URL
return config return config
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock(spec=["get_json"]) self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = b"Synapse Test" self.http_client.user_agent = b"Synapse Test"
@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
sso_handler = hs.get_sso_handler() sso_handler = hs.get_sso_handler()
# Mock the render error method. # Mock the render error method.
self.render_error = Mock(return_value=None) self.render_error = Mock(return_value=None)
sso_handler.render_error = self.render_error sso_handler.render_error = self.render_error # type: ignore[assignment]
# Reduce the number of attempts when generating MXIDs. # Reduce the number of attempts when generating MXIDs.
sso_handler._MAP_USERNAME_RETRIES = 3 sso_handler._MAP_USERNAME_RETRIES = 3
@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
return args return args
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_config(self): def test_config(self) -> None:
"""Basic config correctly sets up the callback URL and client auth correctly.""" """Basic config correctly sets up the callback URL and client auth correctly."""
self.assertEqual(self.provider._callback_url, CALLBACK_URL) self.assertEqual(self.provider._callback_url, CALLBACK_URL)
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID) self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET) self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}}) @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
def test_discovery(self): def test_discovery(self) -> None:
"""The handler should discover the endpoints from OIDC discovery document.""" """The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid # This would throw if some metadata were invalid
metadata = self.get_success(self.provider.load_metadata()) metadata = self.get_success(self.provider.load_metadata())
@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.http_client.get_json.assert_not_called() self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_no_discovery(self): def test_no_discovery(self) -> None:
"""When discovery is disabled, it should not try to load from discovery document.""" """When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata()) self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called() self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_load_jwks(self): def test_load_jwks(self) -> None:
"""JWKS loading is done once (then cached) if used.""" """JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks()) jwks = self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI) self.http_client.get_json.assert_called_once_with(JWKS_URI)
@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_failure(self.provider.load_jwks(force=True), RuntimeError) self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_validate_config(self): def test_validate_config(self) -> None:
"""Provider metadatas are extensively validated.""" """Provider metadatas are extensively validated."""
h = self.provider h = self.provider
@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
force_load_metadata() force_load_metadata()
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}}) @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
def test_skip_verification(self): def test_skip_verification(self) -> None:
"""Provider metadata validation can be disabled by config.""" """Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}): with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw # This should not throw
get_awaitable_result(self.provider.load_metadata()) get_awaitable_result(self.provider.load_metadata())
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_redirect_request(self): def test_redirect_request(self) -> None:
"""The redirect request has the right arguments & generates a valid session cookie.""" """The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["cookies"]) req = Mock(spec=["cookies"])
req.cookies = [] req.cookies = []
@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(redirect, "http://client/redirect") self.assertEqual(redirect, "http://client/redirect")
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_error(self): def test_callback_error(self) -> None:
"""Errors from the provider returned in the callback are displayed.""" """Errors from the provider returned in the callback are displayed."""
request = Mock(args={}) request = Mock(args={})
request.args[b"error"] = [b"invalid_client"] request.args[b"error"] = [b"invalid_client"]
@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("invalid_client", "some description") self.assertRenderedError("invalid_client", "some description")
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback(self): def test_callback(self) -> None:
"""Code callback works and display errors if something went wrong. """Code callback works and display errors if something went wrong.
A lot of scenarios are tested here: A lot of scenarios are tested here:
@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": username, "username": username,
} }
expected_user_id = "@%s:%s" % (username, self.hs.hostname) expected_user_id = "@%s:%s" % (username, self.hs.hostname)
self.provider._exchange_code = simple_async_mock(return_value=token) self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
self.provider._parse_id_token = simple_async_mock(return_value=userinfo) self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("mapping_error") self.assertRenderedError("mapping_error")
# Handle ID token errors # Handle ID token errors
self.provider._parse_id_token = simple_async_mock(raises=Exception()) self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token") self.assertRenderedError("invalid_token")
@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"type": "bearer", "type": "bearer",
"access_token": "access_token", "access_token": "access_token",
} }
self.provider._exchange_code = simple_async_mock(return_value=token) self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
id_token = { id_token = {
"sid": "abcdefgh", "sid": "abcdefgh",
} }
self.provider._parse_id_token = simple_async_mock(return_value=id_token) self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
self.provider._exchange_code = simple_async_mock(return_value=token) self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
self.provider._fetch_userinfo.reset_mock() self.provider._fetch_userinfo.reset_mock()
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.render_error.assert_not_called() self.render_error.assert_not_called()
# Handle userinfo fetching error # Handle userinfo fetching error
self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error") self.assertRenderedError("fetch_error")
# Handle code exchange failure # Handle code exchange failure
from synapse.handlers.oidc import OidcError from synapse.handlers.oidc import OidcError
self.provider._exchange_code = simple_async_mock( self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
raises=OidcError("invalid_request") raises=OidcError("invalid_request")
) )
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request") self.assertRenderedError("invalid_request")
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_session(self): def test_callback_session(self) -> None:
"""The callback verifies the session presence and validity""" """The callback verifies the session presence and validity"""
request = Mock(spec=["args", "getCookie", "cookies"]) request = Mock(spec=["args", "getCookie", "cookies"])
@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config( @override_config(
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}} {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
) )
def test_exchange_code(self): def test_exchange_code(self) -> None:
"""Code exchange behaves correctly and handles various error scenarios.""" """Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"} token = {"type": "bearer"}
token_json = json.dumps(token).encode("utf-8") token_json = json.dumps(token).encode("utf-8")
@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_exchange_code_jwt_key(self): def test_exchange_code_jwt_key(self) -> None:
"""Test that code exchange works with a JWK client secret.""" """Test that code exchange works with a JWK client secret."""
from authlib.jose import jwt from authlib.jose import jwt
@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_exchange_code_no_auth(self): def test_exchange_code_no_auth(self) -> None:
"""Test that code exchange works with no client secret.""" """Test that code exchange works with no client secret."""
token = {"type": "bearer"} token = {"type": "bearer"}
self.http_client.request = simple_async_mock( self.http_client.request = simple_async_mock(
@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_extra_attributes(self): def test_extra_attributes(self) -> None:
""" """
Login while using a mapping provider that implements get_extra_attributes. Login while using a mapping provider that implements get_extra_attributes.
""" """
@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "foo", "username": "foo",
"phone": "1234567", "phone": "1234567",
} }
self.provider._exchange_code = simple_async_mock(return_value=token) self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
self.provider._parse_id_token = simple_async_mock(return_value=userinfo) self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
) )
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_user(self): def test_map_userinfo_to_user(self) -> None:
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
userinfo = { userinfo: dict = {
"sub": "test_user", "sub": "test_user",
"username": "test_user", "username": "test_user",
} }
@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
) )
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}}) @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self): def test_map_userinfo_to_existing_user(self) -> None:
"""Existing users can log in with OpenID Connect when allow_existing_users is True.""" """Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
user = UserID.from_string("@test_user:test") user = UserID.from_string("@test_user:test")
@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
) )
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_invalid_localpart(self): def test_map_userinfo_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected.""" """If the mapping provider generates an invalid localpart it should be rejected."""
self.get_success( self.get_success(
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_map_userinfo_to_user_retries(self): def test_map_userinfo_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use.""" """The mapping provider can retry generating an MXID if the MXID is already in use."""
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
) )
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_empty_localpart(self): def test_empty_localpart(self) -> None:
"""Attempts to map onto an empty localpart should be rejected.""" """Attempts to map onto an empty localpart should be rejected."""
userinfo = { userinfo = {
"sub": "tester", "sub": "tester",
@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_null_localpart(self): def test_null_localpart(self) -> None:
"""Mapping onto a null localpart via an empty OIDC attribute should be rejected""" """Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
userinfo = { userinfo = {
"sub": "tester", "sub": "tester",
@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_attribute_requirements(self): def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the OIDC userinfo response.""" """The required attributes must be met from the OIDC userinfo response."""
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_attribute_requirements_contains(self): def test_attribute_requirements_contains(self) -> None:
"""Test that auth succeeds if userinfo attribute CONTAINS required value""" """Test that auth succeeds if userinfo attribute CONTAINS required value"""
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_attribute_requirements_mismatch(self): def test_attribute_requirements_mismatch(self) -> None:
""" """
Test that auth fails if attributes exist but don't match, Test that auth fails if attributes exist but don't match,
or are non-string values. or are non-string values.
@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
# userinfo with "test": "not_foobar" attribute should fail # userinfo with "test": "not_foobar" attribute should fail
userinfo = { userinfo: dict = {
"sub": "tester", "sub": "tester",
"username": "tester", "username": "tester",
"test": "not_foobar", "test": "not_foobar",
@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo(
handler = hs.get_oidc_handler() handler = hs.get_oidc_handler()
provider = handler._providers["oidc"] provider = handler._providers["oidc"]
provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
provider._parse_id_token = simple_async_mock(return_value=userinfo) provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
provider._fetch_userinfo = simple_async_mock(return_value=userinfo) provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
state = "state" state = "state"
session = handler._token_generator.generate_oidc_session_token( session = handler._token_generator.generate_oidc_session_token(

View file

@ -331,11 +331,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
# Extract presence update user ID and state information into lists of tuples # Extract presence update user ID and state information into lists of tuples
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]] db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
presence_states = [(ps.user_id, ps.state) for ps in presence_states] presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
# Compare what we put into the storage with what we got out. # Compare what we put into the storage with what we got out.
# They should be identical. # They should be identical.
self.assertEqual(presence_states, db_presence_states) self.assertEqual(presence_states_compare, db_presence_states)
class PresenceTimeoutTestCase(unittest.TestCase): class PresenceTimeoutTestCase(unittest.TestCase):
@ -357,6 +357,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
assert new_state is not None
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE) self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
self.assertEqual(new_state.status_msg, status_msg) self.assertEqual(new_state.status_msg, status_msg)
@ -380,6 +381,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
assert new_state is not None
self.assertEqual(new_state.state, PresenceState.BUSY) self.assertEqual(new_state.state, PresenceState.BUSY)
self.assertEqual(new_state.status_msg, status_msg) self.assertEqual(new_state.status_msg, status_msg)
@ -399,6 +401,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
assert new_state is not None
self.assertEqual(new_state.state, PresenceState.OFFLINE) self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg) self.assertEqual(new_state.status_msg, status_msg)
@ -420,6 +423,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
) )
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
assert new_state is not None
self.assertEqual(new_state.state, PresenceState.ONLINE) self.assertEqual(new_state.state, PresenceState.ONLINE)
self.assertEqual(new_state.status_msg, status_msg) self.assertEqual(new_state.status_msg, status_msg)
@ -477,6 +481,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
) )
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
assert new_state is not None
self.assertEqual(new_state.state, PresenceState.OFFLINE) self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg) self.assertEqual(new_state.status_msg, status_msg)
@ -653,13 +658,13 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None) self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
def _set_presencestate_with_status_msg( def _set_presencestate_with_status_msg(
self, user_id: str, state: PresenceState, status_msg: Optional[str] self, user_id: str, state: str, status_msg: Optional[str]
): ):
"""Set a PresenceState and status_msg and check the result. """Set a PresenceState and status_msg and check the result.
Args: Args:
user_id: User for that the status is to be set. user_id: User for that the status is to be set.
PresenceState: The new PresenceState. state: The new PresenceState.
status_msg: Status message that is to be set. status_msg: Status message that is to be set.
""" """
self.get_success( self.get_success(

View file

@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
import synapse.types import synapse.types
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.rest import admin from synapse.rest import admin
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import UserID from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
servlets = [admin.register_servlets] servlets = [admin.register_servlets]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock() self.mock_federation = Mock()
self.mock_registry = Mock() self.mock_registry = Mock()
self.query_handlers = {} self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
def register_query_handler(query_type, handler): def register_query_handler(
query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
) -> None:
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
self.mock_registry.register_query_handler = register_query_handler self.mock_registry.register_query_handler = register_query_handler
@ -47,7 +52,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
) )
return hs return hs
def prepare(self, reactor, clock, hs: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.frank = UserID.from_string("@1234abcd:test") self.frank = UserID.from_string("@1234abcd:test")
@ -58,7 +63,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.handler = hs.get_profile_handler() self.handler = hs.get_profile_handler()
def test_get_my_name(self): def test_get_my_name(self) -> None:
self.get_success( self.get_success(
self.store.set_profile_displayname(self.frank.localpart, "Frank") self.store.set_profile_displayname(self.frank.localpart, "Frank")
) )
@ -67,7 +72,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual("Frank", displayname) self.assertEqual("Frank", displayname)
def test_set_my_name(self): def test_set_my_name(self) -> None:
self.get_success( self.get_success(
self.handler.set_displayname( self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr." self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
@ -110,7 +115,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.get_profile_displayname(self.frank.localpart)) self.get_success(self.store.get_profile_displayname(self.frank.localpart))
) )
def test_set_my_name_if_disabled(self): def test_set_my_name_if_disabled(self) -> None:
self.hs.config.registration.enable_set_displayname = False self.hs.config.registration.enable_set_displayname = False
# Setting displayname for the first time is allowed # Setting displayname for the first time is allowed
@ -135,7 +140,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
SynapseError, SynapseError,
) )
def test_set_my_name_noauth(self): def test_set_my_name_noauth(self) -> None:
self.get_failure( self.get_failure(
self.handler.set_displayname( self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr." self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
@ -143,7 +148,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
AuthError, AuthError,
) )
def test_get_other_name(self): def test_get_other_name(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable( self.mock_federation.make_query.return_value = make_awaitable(
{"displayname": "Alice"} {"displayname": "Alice"}
) )
@ -158,7 +163,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
ignore_backoff=True, ignore_backoff=True,
) )
def test_incoming_fed_query(self): def test_incoming_fed_query(self) -> None:
self.get_success(self.store.create_profile("caroline")) self.get_success(self.store.create_profile("caroline"))
self.get_success(self.store.set_profile_displayname("caroline", "Caroline")) self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
@ -174,7 +179,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual({"displayname": "Caroline"}, response) self.assertEqual({"displayname": "Caroline"}, response)
def test_get_my_avatar(self): def test_get_my_avatar(self) -> None:
self.get_success( self.get_success(
self.store.set_profile_avatar_url( self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png" self.frank.localpart, "http://my.server/me.png"
@ -184,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual("http://my.server/me.png", avatar_url) self.assertEqual("http://my.server/me.png", avatar_url)
def test_set_my_avatar(self): def test_set_my_avatar(self) -> None:
self.get_success( self.get_success(
self.handler.set_avatar_url( self.handler.set_avatar_url(
self.frank, self.frank,
@ -225,7 +230,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
) )
def test_set_my_avatar_if_disabled(self): def test_set_my_avatar_if_disabled(self) -> None:
self.hs.config.registration.enable_set_avatar_url = False self.hs.config.registration.enable_set_avatar_url = False
# Setting displayname for the first time is allowed # Setting displayname for the first time is allowed
@ -250,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
SynapseError, SynapseError,
) )
def test_avatar_constraints_no_config(self): def test_avatar_constraints_no_config(self) -> None:
"""Tests that the method to check an avatar against configured constraints skips """Tests that the method to check an avatar against configured constraints skips
all of its check if no constraint is configured. all of its check if no constraint is configured.
""" """
@ -263,7 +268,13 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertTrue(res) self.assertTrue(res)
@unittest.override_config({"max_avatar_size": 50}) @unittest.override_config({"max_avatar_size": 50})
def test_avatar_constraints_missing(self): def test_avatar_constraints_allow_empty_avatar_url(self) -> None:
"""An empty avatar is always permitted."""
res = self.get_success(self.handler.check_avatar_size_and_mime_type(""))
self.assertTrue(res)
@unittest.override_config({"max_avatar_size": 50})
def test_avatar_constraints_missing(self) -> None:
"""Tests that an avatar isn't allowed if the file at the given MXC URI couldn't """Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
be found. be found.
""" """
@ -273,7 +284,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertFalse(res) self.assertFalse(res)
@unittest.override_config({"max_avatar_size": 50}) @unittest.override_config({"max_avatar_size": 50})
def test_avatar_constraints_file_size(self): def test_avatar_constraints_file_size(self) -> None:
"""Tests that a file that's above the allowed file size is forbidden but one """Tests that a file that's above the allowed file size is forbidden but one
that's below it is allowed. that's below it is allowed.
""" """
@ -295,7 +306,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertFalse(res) self.assertFalse(res)
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
def test_avatar_constraint_mime_type(self): def test_avatar_constraint_mime_type(self) -> None:
"""Tests that a file with an unauthorised MIME type is forbidden but one with """Tests that a file with an unauthorised MIME type is forbidden but one with
an authorised content type is allowed. an authorised content type is allowed.
""" """

View file

@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional from typing import Any, Dict, Optional
from unittest.mock import Mock from unittest.mock import Mock
import attr import attr
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import RedirectException from synapse.api.errors import RedirectException
from synapse.server import HomeServer
from synapse.util import Clock
from tests.test_utils import simple_async_mock from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
@ -81,10 +85,10 @@ class TestRedirectMappingProvider(TestMappingProvider):
class SamlHandlerTestCase(HomeserverTestCase): class SamlHandlerTestCase(HomeserverTestCase):
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
config["public_baseurl"] = BASE_URL config["public_baseurl"] = BASE_URL
saml_config = { saml_config: Dict[str, Any] = {
"sp_config": {"metadata": {}}, "sp_config": {"metadata": {}},
# Disable grandfathering. # Disable grandfathering.
"grandfathered_mxid_source_attribute": None, "grandfathered_mxid_source_attribute": None,
@ -98,7 +102,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
return config return config
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver() hs = self.setup_test_homeserver()
self.handler = hs.get_saml_handler() self.handler = hs.get_saml_handler()
@ -114,7 +118,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
elif not has_xmlsec1: elif not has_xmlsec1:
skip = "Requires xmlsec1" skip = "Requires xmlsec1"
def test_map_saml_response_to_user(self): def test_map_saml_response_to_user(self) -> None:
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
# stub out the auth handler # stub out the auth handler
@ -140,7 +144,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
) )
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
def test_map_saml_response_to_existing_user(self): def test_map_saml_response_to_existing_user(self) -> None:
"""Existing users can log in with SAML account.""" """Existing users can log in with SAML account."""
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
self.get_success( self.get_success(
@ -186,7 +190,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None, auth_provider_session_id=None,
) )
def test_map_saml_response_to_invalid_localpart(self): def test_map_saml_response_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected.""" """If the mapping provider generates an invalid localpart it should be rejected."""
# stub out the auth handler # stub out the auth handler
@ -207,7 +211,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
) )
auth_handler.complete_sso_login.assert_not_called() auth_handler.complete_sso_login.assert_not_called()
def test_map_saml_response_to_user_retries(self): def test_map_saml_response_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use.""" """The mapping provider can retry generating an MXID if the MXID is already in use."""
# stub out the auth handler and error renderer # stub out the auth handler and error renderer
@ -271,7 +275,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_map_saml_response_redirect(self): def test_map_saml_response_redirect(self) -> None:
"""Test a mapping provider that raises a RedirectException""" """Test a mapping provider that raises a RedirectException"""
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
@ -292,7 +296,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
}, },
} }
) )
def test_attribute_requirements(self): def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the SAML response.""" """The required attributes must be met from the SAML response."""
# stub out the auth handler # stub out the auth handler

View file

@ -18,11 +18,14 @@ from typing import Dict
from unittest.mock import ANY, Mock, call from unittest.mock import ANY, Mock, call
from twisted.internet import defer from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer from synapse.federation.transport.server import TransportLayerServer
from synapse.types import UserID, create_requester from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
@ -42,7 +45,9 @@ ROOM_ID = "a-room"
OTHER_ROOM_ID = "another-room" OTHER_ROOM_ID = "another-room"
def _expect_edu_transaction(edu_type, content, origin="test"): def _expect_edu_transaction(
edu_type: str, content: JsonDict, origin: str = "test"
) -> JsonDict:
return { return {
"origin": origin, "origin": origin,
"origin_server_ts": 1000000, "origin_server_ts": 1000000,
@ -51,12 +56,12 @@ def _expect_edu_transaction(edu_type, content, origin="test"):
} }
def _make_edu_transaction_json(edu_type, content): def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes:
return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8") return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")
class TypingNotificationsTestCase(unittest.HomeserverTestCase): class TypingNotificationsTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# we mock out the keyring so as to skip the authentication check on the # we mock out the keyring so as to skip the authentication check on the
# federation API call. # federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"]) mock_keyring = Mock(spec=["verify_json_for_server"])
@ -83,7 +88,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
d["/_matrix/federation"] = TransportLayerServer(self.hs) d["/_matrix/federation"] = TransportLayerServer(self.hs)
return d return d
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
mock_notifier = hs.get_notifier() mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event self.on_new_event = mock_notifier.on_new_event
@ -111,24 +116,24 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = [] self.room_members = []
async def check_user_in_room(room_id, user_id): async def check_user_in_room(room_id: str, user_id: str) -> None:
if user_id not in [u.to_string() for u in self.room_members]: if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room") raise AuthError(401, "User is not in the room")
return None return None
hs.get_auth().check_user_in_room = check_user_in_room hs.get_auth().check_user_in_room = check_user_in_room
async def check_host_in_room(room_id, server_name): async def check_host_in_room(room_id: str, server_name: str) -> bool:
return room_id == ROOM_ID return room_id == ROOM_ID
hs.get_event_auth_handler().check_host_in_room = check_host_in_room hs.get_event_auth_handler().check_host_in_room = check_host_in_room
def get_joined_hosts_for_room(room_id): def get_joined_hosts_for_room(room_id: str):
return {member.domain for member in self.room_members} return {member.domain for member in self.room_members}
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
async def get_users_in_room(room_id): async def get_users_in_room(room_id: str):
return {str(u) for u in self.room_members} return {str(u) for u in self.room_members}
self.datastore.get_users_in_room = get_users_in_room self.datastore.get_users_in_room = get_users_in_room
@ -153,7 +158,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
lambda *args, **kwargs: make_awaitable(None) lambda *args, **kwargs: make_awaitable(None)
) )
def test_started_typing_local(self): def test_started_typing_local(self) -> None:
self.room_members = [U_APPLE, U_BANANA] self.room_members = [U_APPLE, U_BANANA]
self.assertEqual(self.event_source.get_current_key(), 0) self.assertEqual(self.event_source.get_current_key(), 0)
@ -187,7 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
) )
@override_config({"send_federation": True}) @override_config({"send_federation": True})
def test_started_typing_remote_send(self): def test_started_typing_remote_send(self) -> None:
self.room_members = [U_APPLE, U_ONION] self.room_members = [U_APPLE, U_ONION]
self.get_success( self.get_success(
@ -217,7 +222,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
try_trailing_slash_on_400=True, try_trailing_slash_on_400=True,
) )
def test_started_typing_remote_recv(self): def test_started_typing_remote_recv(self) -> None:
self.room_members = [U_APPLE, U_ONION] self.room_members = [U_APPLE, U_ONION]
self.assertEqual(self.event_source.get_current_key(), 0) self.assertEqual(self.event_source.get_current_key(), 0)
@ -256,7 +261,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_started_typing_remote_recv_not_in_room(self): def test_started_typing_remote_recv_not_in_room(self) -> None:
self.room_members = [U_APPLE, U_ONION] self.room_members = [U_APPLE, U_ONION]
self.assertEqual(self.event_source.get_current_key(), 0) self.assertEqual(self.event_source.get_current_key(), 0)
@ -292,7 +297,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[1], 0) self.assertEqual(events[1], 0)
@override_config({"send_federation": True}) @override_config({"send_federation": True})
def test_stopped_typing(self): def test_stopped_typing(self) -> None:
self.room_members = [U_APPLE, U_BANANA, U_ONION] self.room_members = [U_APPLE, U_BANANA, U_ONION]
# Gut-wrenching # Gut-wrenching
@ -343,7 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
) )
def test_typing_timeout(self): def test_typing_timeout(self) -> None:
self.room_members = [U_APPLE, U_BANANA] self.room_members = [U_APPLE, U_BANANA]
self.assertEqual(self.event_source.get_current_key(), 0) self.assertEqual(self.event_source.get_current_key(), 0)

View file

@ -86,6 +86,16 @@ class ModuleApiTestCase(HomeserverTestCase):
displayname = self.get_success(self.store.get_profile_displayname("bob")) displayname = self.get_success(self.store.get_profile_displayname("bob"))
self.assertEqual(displayname, "Bobberino") self.assertEqual(displayname, "Bobberino")
def test_can_register_admin_user(self):
user_id = self.get_success(
self.register_user(
"bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
)
)
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, True)
def test_get_userinfo_by_id(self): def test_get_userinfo_by_id(self):
user_id = self.register_user("alice", "1234") user_id = self.register_user("alice", "1234")
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))

View file

@ -11,15 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Tuple from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.rest.client import login, push_rule, receipts, room from synapse.rest.client import login, push_rule, receipts, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
@ -35,13 +39,13 @@ class HTTPPusherTests(HomeserverTestCase):
user_id = True user_id = True
hijack_auth = False hijack_auth = False
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
config["start_pushers"] = True config["start_pushers"] = True
return config return config
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.push_attempts: List[tuple[Deferred, str, dict]] = [] self.push_attempts: List[Tuple[Deferred, str, dict]] = []
m = Mock() m = Mock()
@ -56,7 +60,7 @@ class HTTPPusherTests(HomeserverTestCase):
return hs return hs
def test_invalid_configuration(self): def test_invalid_configuration(self) -> None:
"""Invalid push configurations should be rejected.""" """Invalid push configurations should be rejected."""
# Register the user who gets notified # Register the user who gets notified
user_id = self.register_user("user", "pass") user_id = self.register_user("user", "pass")
@ -68,7 +72,7 @@ class HTTPPusherTests(HomeserverTestCase):
) )
token_id = user_tuple.token_id token_id = user_tuple.token_id
def test_data(data): def test_data(data: Optional[JsonDict]) -> None:
self.get_failure( self.get_failure(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(
user_id=user_id, user_id=user_id,
@ -95,7 +99,7 @@ class HTTPPusherTests(HomeserverTestCase):
# A url with an incorrect path isn't accepted. # A url with an incorrect path isn't accepted.
test_data({"url": "http://example.com/foo"}) test_data({"url": "http://example.com/foo"})
def test_sends_http(self): def test_sends_http(self) -> None:
""" """
The HTTP pusher will send pushes for each message to a HTTP endpoint The HTTP pusher will send pushes for each message to a HTTP endpoint
when configured to do so. when configured to do so.
@ -200,7 +204,7 @@ class HTTPPusherTests(HomeserverTestCase):
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
def test_sends_high_priority_for_encrypted(self): def test_sends_high_priority_for_encrypted(self) -> None:
""" """
The HTTP pusher will send pushes at high priority if they correspond The HTTP pusher will send pushes at high priority if they correspond
to an encrypted message. to an encrypted message.
@ -321,7 +325,7 @@ class HTTPPusherTests(HomeserverTestCase):
) )
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
def test_sends_high_priority_for_one_to_one_only(self): def test_sends_high_priority_for_one_to_one_only(self) -> None:
""" """
The HTTP pusher will send pushes at high priority if they correspond The HTTP pusher will send pushes at high priority if they correspond
to a message in a one-to-one room. to a message in a one-to-one room.
@ -404,7 +408,7 @@ class HTTPPusherTests(HomeserverTestCase):
# check that this is low-priority # check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
def test_sends_high_priority_for_mention(self): def test_sends_high_priority_for_mention(self) -> None:
""" """
The HTTP pusher will send pushes at high priority if they correspond The HTTP pusher will send pushes at high priority if they correspond
to a message containing the user's display name. to a message containing the user's display name.
@ -480,7 +484,7 @@ class HTTPPusherTests(HomeserverTestCase):
# check that this is low-priority # check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
def test_sends_high_priority_for_atroom(self): def test_sends_high_priority_for_atroom(self) -> None:
""" """
The HTTP pusher will send pushes at high priority if they correspond The HTTP pusher will send pushes at high priority if they correspond
to a message that contains @room. to a message that contains @room.
@ -563,7 +567,7 @@ class HTTPPusherTests(HomeserverTestCase):
# check that this is low-priority # check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
def test_push_unread_count_group_by_room(self): def test_push_unread_count_group_by_room(self) -> None:
""" """
The HTTP pusher will group unread count by number of unread rooms. The HTTP pusher will group unread count by number of unread rooms.
""" """
@ -576,7 +580,7 @@ class HTTPPusherTests(HomeserverTestCase):
self._check_push_attempt(6, 1) self._check_push_attempt(6, 1)
@override_config({"push": {"group_unread_count_by_room": False}}) @override_config({"push": {"group_unread_count_by_room": False}})
def test_push_unread_count_message_count(self): def test_push_unread_count_message_count(self) -> None:
""" """
The HTTP pusher will send the total unread message count. The HTTP pusher will send the total unread message count.
""" """
@ -589,7 +593,7 @@ class HTTPPusherTests(HomeserverTestCase):
# last read receipt # last read receipt
self._check_push_attempt(6, 3) self._check_push_attempt(6, 3)
def _test_push_unread_count(self): def _test_push_unread_count(self) -> None:
""" """
Tests that the correct unread count appears in sent push notifications Tests that the correct unread count appears in sent push notifications
@ -681,7 +685,7 @@ class HTTPPusherTests(HomeserverTestCase):
self.helper.send(room_id, body="HELLO???", tok=other_access_token) self.helper.send(room_id, body="HELLO???", tok=other_access_token)
def _advance_time_and_make_push_succeed(self, expected_push_attempts): def _advance_time_and_make_push_succeed(self, expected_push_attempts: int) -> None:
self.pump() self.pump()
self.push_attempts[expected_push_attempts - 1][0].callback({}) self.push_attempts[expected_push_attempts - 1][0].callback({})
@ -708,7 +712,9 @@ class HTTPPusherTests(HomeserverTestCase):
expected_unread_count_last_push, expected_unread_count_last_push,
) )
def _send_read_request(self, access_token, message_event_id, room_id): def _send_read_request(
self, access_token: str, message_event_id: str, room_id: str
) -> None:
# Now set the user's read receipt position to the first event # Now set the user's read receipt position to the first event
# #
# This will actually trigger a new notification to be sent out so that # This will actually trigger a new notification to be sent out so that
@ -748,7 +754,7 @@ class HTTPPusherTests(HomeserverTestCase):
return user_id, access_token return user_id, access_token
def test_dont_notify_rule_overrides_message(self): def test_dont_notify_rule_overrides_message(self) -> None:
""" """
The override push rule will suppress notification The override push rule will suppress notification
""" """

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict from typing import Dict, Optional, Union
import frozendict import frozendict
@ -20,12 +20,13 @@ from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.push import push_rule_evaluator from synapse.push import push_rule_evaluator
from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.types import JsonDict
from tests import unittest from tests import unittest
class PushRuleEvaluatorTestCase(unittest.TestCase): class PushRuleEvaluatorTestCase(unittest.TestCase):
def _get_evaluator(self, content): def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent:
event = FrozenEvent( event = FrozenEvent(
{ {
"event_id": "$event_id", "event_id": "$event_id",
@ -39,12 +40,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
) )
room_member_count = 0 room_member_count = 0
sender_power_level = 0 sender_power_level = 0
power_levels = {} power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
return PushRuleEvaluatorForEvent( return PushRuleEvaluatorForEvent(
event, room_member_count, sender_power_level, power_levels event, room_member_count, sender_power_level, power_levels
) )
def test_display_name(self): def test_display_name(self) -> None:
"""Check for a matching display name in the body of the event.""" """Check for a matching display name in the body of the event."""
evaluator = self._get_evaluator({"body": "foo bar baz"}) evaluator = self._get_evaluator({"body": "foo bar baz"})
@ -71,20 +72,20 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
def _assert_matches( def _assert_matches(
self, condition: Dict[str, Any], content: Dict[str, Any], msg=None self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
) -> None: ) -> None:
evaluator = self._get_evaluator(content) evaluator = self._get_evaluator(content)
self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg) self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
def _assert_not_matches( def _assert_not_matches(
self, condition: Dict[str, Any], content: Dict[str, Any], msg=None self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
) -> None: ) -> None:
evaluator = self._get_evaluator(content) evaluator = self._get_evaluator(content)
self.assertFalse( self.assertFalse(
evaluator.matches(condition, "@user:test", "display_name"), msg evaluator.matches(condition, "@user:test", "display_name"), msg
) )
def test_event_match_body(self): def test_event_match_body(self) -> None:
"""Check that event_match conditions on content.body work as expected""" """Check that event_match conditions on content.body work as expected"""
# if the key is `content.body`, the pattern matches substrings. # if the key is `content.body`, the pattern matches substrings.
@ -165,7 +166,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
r"? after \ should match any character", r"? after \ should match any character",
) )
def test_event_match_non_body(self): def test_event_match_non_body(self) -> None:
"""Check that event_match conditions on other keys work as expected""" """Check that event_match conditions on other keys work as expected"""
# if the key is anything other than 'content.body', the pattern must match the # if the key is anything other than 'content.body', the pattern must match the
@ -241,7 +242,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"pattern should not match before a newline", "pattern should not match before a newline",
) )
def test_no_body(self): def test_no_body(self) -> None:
"""Not having a body shouldn't break the evaluator.""" """Not having a body shouldn't break the evaluator."""
evaluator = self._get_evaluator({}) evaluator = self._get_evaluator({})
@ -250,7 +251,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
} }
self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
def test_invalid_body(self): def test_invalid_body(self) -> None:
"""A non-string body should not break the evaluator.""" """A non-string body should not break the evaluator."""
condition = { condition = {
"kind": "contains_display_name", "kind": "contains_display_name",
@ -260,7 +261,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
evaluator = self._get_evaluator({"body": body}) evaluator = self._get_evaluator({"body": body})
self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
def test_tweaks_for_actions(self): def test_tweaks_for_actions(self) -> None:
""" """
This tests the behaviour of tweaks_for_actions. This tests the behaviour of tweaks_for_actions.
""" """

View file

@ -1050,6 +1050,25 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self._is_erased("@user:test", True) self._is_erased("@user:test", True)
@override_config({"max_avatar_size": 1234})
def test_deactivate_user_erase_true_avatar_nonnull_but_empty(self) -> None:
"""Check we can erase a user whose avatar is the empty string.
Reproduces #12257.
"""
# Patch `self.other_user` to have an empty string as their avatar.
self.get_success(self.store.set_profile_avatar_url("user", ""))
# Check we can still erase them.
channel = self.make_request(
"POST",
self.url,
access_token=self.admin_user_tok,
content={"erase": True},
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self._is_erased("@user:test", True)
def test_deactivate_user_erase_false(self) -> None: def test_deactivate_user_erase_false(self) -> None:
""" """
Test deactivating a user and set `erase` to `false` Test deactivating a user and set `erase` to `false`

View file

@ -31,7 +31,7 @@ from synapse.rest import admin
from synapse.rest.client import account, login, register, room from synapse.rest.client import account, login, register, room
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict, UserID
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -1222,6 +1222,62 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_failures=[users[2]], expected_failures=[users[2]],
) )
@unittest.override_config(
{
"use_account_validity_in_account_status": True,
}
)
def test_no_account_validity(self) -> None:
"""Tests that if we decide to include account validity in the response but no
account validity 'is_user_expired' callback is provided, we default to marking all
users as not expired.
"""
user = self.register_user("someuser", "password")
self._test_status(
users=[user],
expected_statuses={
user: {
"exists": True,
"deactivated": False,
"org.matrix.expired": False,
},
},
expected_failures=[],
)
@unittest.override_config(
{
"use_account_validity_in_account_status": True,
}
)
def test_account_validity_expired(self) -> None:
"""Test that if we decide to include account validity in the response and the user
is expired, we return the correct info.
"""
user = self.register_user("someuser", "password")
async def is_expired(user_id: str) -> bool:
# We can't blindly say everyone is expired, otherwise the request to get the
# account status will fail.
return UserID.from_string(user_id).localpart == "someuser"
self.hs.get_account_validity_handler()._is_user_expired_callbacks.append(
is_expired
)
self._test_status(
users=[user],
expected_statuses={
user: {
"exists": True,
"deactivated": False,
"org.matrix.expired": True,
},
},
expected_failures=[],
)
def _test_status( def _test_status(
self, self,
users: Optional[List[str]], users: Optional[List[str]],

View file

@ -14,7 +14,7 @@
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.rest.client import login, room, shared_rooms from synapse.rest.client import login, mutual_rooms, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
@ -22,16 +22,16 @@ from tests import unittest
from tests.server import FakeChannel from tests.server import FakeChannel
class UserSharedRoomsTest(unittest.HomeserverTestCase): class UserMutualRoomsTest(unittest.HomeserverTestCase):
""" """
Tests the UserSharedRoomsServlet. Tests the UserMutualRoomsServlet.
""" """
servlets = [ servlets = [
login.register_servlets, login.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource, synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets, room.register_servlets,
shared_rooms.register_servlets, mutual_rooms.register_servlets,
] ]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@ -43,10 +43,10 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.handler = hs.get_user_directory_handler() self.handler = hs.get_user_directory_handler()
def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel: def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel:
return self.make_request( return self.make_request(
"GET", "GET",
"/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" "/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms/%s"
% other_user, % other_user,
access_token=token, access_token=token,
) )
@ -56,14 +56,14 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
A room should show up in the shared list of rooms between two users A room should show up in the shared list of rooms between two users
if it is public. if it is public.
""" """
self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=True)
def test_shared_room_list_private(self) -> None: def test_shared_room_list_private(self) -> None:
""" """
A room should show up in the shared list of rooms between two users A room should show up in the shared list of rooms between two users
if it is private. if it is private.
""" """
self._check_shared_rooms_with( self._check_mutual_rooms_with(
room_one_is_public=False, room_two_is_public=False room_one_is_public=False, room_two_is_public=False
) )
@ -72,9 +72,9 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
The shared room list between two users should contain both public and private The shared room list between two users should contain both public and private
rooms. rooms.
""" """
self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False) self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=False)
def _check_shared_rooms_with( def _check_mutual_rooms_with(
self, room_one_is_public: bool, room_two_is_public: bool self, room_one_is_public: bool, room_two_is_public: bool
) -> None: ) -> None:
"""Checks that shared public or private rooms between two users appear in """Checks that shared public or private rooms between two users appear in
@ -94,7 +94,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
# Check shared rooms from user1's perspective. # Check shared rooms from user1's perspective.
# We should see the one room in common # We should see the one room in common
channel = self._get_shared_rooms(u1_token, u2) channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result) self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 1) self.assertEqual(len(channel.json_body["joined"]), 1)
self.assertEqual(channel.json_body["joined"][0], room_id_one) self.assertEqual(channel.json_body["joined"][0], room_id_one)
@ -107,7 +107,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.join(room_id_two, user=u2, tok=u2_token) self.helper.join(room_id_two, user=u2, tok=u2_token)
# Check shared rooms again. We should now see both rooms. # Check shared rooms again. We should now see both rooms.
channel = self._get_shared_rooms(u1_token, u2) channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result) self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 2) self.assertEqual(len(channel.json_body["joined"]), 2)
for room_id_id in channel.json_body["joined"]: for room_id_id in channel.json_body["joined"]:
@ -128,7 +128,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.join(room, user=u2, tok=u2_token) self.helper.join(room, user=u2, tok=u2_token)
# Assert user directory is not empty # Assert user directory is not empty
channel = self._get_shared_rooms(u1_token, u2) channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result) self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 1) self.assertEqual(len(channel.json_body["joined"]), 1)
self.assertEqual(channel.json_body["joined"][0], room) self.assertEqual(channel.json_body["joined"][0], room)
@ -136,11 +136,11 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.leave(room, user=u1, tok=u1_token) self.helper.leave(room, user=u1, tok=u1_token)
# Check user1's view of shared rooms with user2 # Check user1's view of shared rooms with user2
channel = self._get_shared_rooms(u1_token, u2) channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result) self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 0) self.assertEqual(len(channel.json_body["joined"]), 0)
# Check user2's view of shared rooms with user1 # Check user2's view of shared rooms with user1
channel = self._get_shared_rooms(u2_token, u1) channel = self._get_mutual_rooms(u2_token, u1)
self.assertEqual(200, channel.code, channel.result) self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 0) self.assertEqual(len(channel.json_body["joined"]), 0)

File diff suppressed because it is too large Load diff

View file

@ -16,7 +16,6 @@ from synapse.rest.media.v1.preview_html import (
_get_html_media_encodings, _get_html_media_encodings,
decode_body, decode_body,
parse_html_to_open_graph, parse_html_to_open_graph,
rebase_url,
summarize_paragraphs, summarize_paragraphs,
) )
@ -161,7 +160,7 @@ class CalcOgTestCase(unittest.TestCase):
""" """
tree = decode_body(html, "http://example.com/test.html") tree = decode_body(html, "http://example.com/test.html")
og = parse_html_to_open_graph(tree, "http://example.com/test.html") og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@ -177,7 +176,7 @@ class CalcOgTestCase(unittest.TestCase):
""" """
tree = decode_body(html, "http://example.com/test.html") tree = decode_body(html, "http://example.com/test.html")
og = parse_html_to_open_graph(tree, "http://example.com/test.html") og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@ -196,7 +195,7 @@ class CalcOgTestCase(unittest.TestCase):
""" """
tree = decode_body(html, "http://example.com/test.html") tree = decode_body(html, "http://example.com/test.html")
og = parse_html_to_open_graph(tree, "http://example.com/test.html") og = parse_html_to_open_graph(tree)
self.assertEqual( self.assertEqual(
og, og,
@ -218,7 +217,7 @@ class CalcOgTestCase(unittest.TestCase):
""" """
tree = decode_body(html, "http://example.com/test.html") tree = decode_body(html, "http://example.com/test.html")
og = parse_html_to_open_graph(tree, "http://example.com/test.html") og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@ -232,7 +231,7 @@ class CalcOgTestCase(unittest.TestCase):
""" """
tree = decode_body(html, "http://example.com/test.html") tree = decode_body(html, "http://example.com/test.html")
og = parse_html_to_open_graph(tree, "http://example.com/test.html") og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@ -247,7 +246,7 @@ class CalcOgTestCase(unittest.TestCase):
""" """
tree = decode_body(html, "http://example.com/test.html") tree = decode_body(html, "http://example.com/test.html")
og = parse_html_to_open_graph(tree, "http://example.com/test.html") og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
@ -262,7 +261,7 @@ class CalcOgTestCase(unittest.TestCase):
""" """
tree = decode_body(html, "http://example.com/test.html") tree = decode_body(html, "http://example.com/test.html")
og = parse_html_to_open_graph(tree, "http://example.com/test.html") og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@ -290,7 +289,7 @@ class CalcOgTestCase(unittest.TestCase):
<head><title>Foo</title></head><body>Some text.</body></html> <head><title>Foo</title></head><body>Some text.</body></html>
""".strip() """.strip()
tree = decode_body(html, "http://example.com/test.html") tree = decode_body(html, "http://example.com/test.html")
og = parse_html_to_open_graph(tree, "http://example.com/test.html") og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_invalid_encoding(self) -> None: def test_invalid_encoding(self) -> None:
@ -304,7 +303,7 @@ class CalcOgTestCase(unittest.TestCase):
</html> </html>
""" """
tree = decode_body(html, "http://example.com/test.html", "invalid-encoding") tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
og = parse_html_to_open_graph(tree, "http://example.com/test.html") og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_invalid_encoding2(self) -> None: def test_invalid_encoding2(self) -> None:
@ -319,7 +318,7 @@ class CalcOgTestCase(unittest.TestCase):
</html> </html>
""" """
tree = decode_body(html, "http://example.com/test.html") tree = decode_body(html, "http://example.com/test.html")
og = parse_html_to_open_graph(tree, "http://example.com/test.html") og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
def test_windows_1252(self) -> None: def test_windows_1252(self) -> None:
@ -333,7 +332,7 @@ class CalcOgTestCase(unittest.TestCase):
</html> </html>
""" """
tree = decode_body(html, "http://example.com/test.html") tree = decode_body(html, "http://example.com/test.html")
og = parse_html_to_open_graph(tree, "http://example.com/test.html") og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."}) self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
@ -448,34 +447,3 @@ class MediaEncodingTestCase(unittest.TestCase):
'text/html; charset="invalid"', 'text/html; charset="invalid"',
) )
self.assertEqual(list(encodings), ["utf-8", "cp1252"]) self.assertEqual(list(encodings), ["utf-8", "cp1252"])
class RebaseUrlTestCase(unittest.TestCase):
def test_relative(self) -> None:
"""Relative URLs should be resolved based on the context of the base URL."""
self.assertEqual(
rebase_url("subpage", "https://example.com/foo/"),
"https://example.com/foo/subpage",
)
self.assertEqual(
rebase_url("sibling", "https://example.com/foo"),
"https://example.com/sibling",
)
self.assertEqual(
rebase_url("/bar", "https://example.com/foo/"),
"https://example.com/bar",
)
def test_absolute(self) -> None:
"""Absolute URLs should not be modified."""
self.assertEqual(
rebase_url("https://alice.com/a/", "https://example.com/foo/"),
"https://alice.com/a/",
)
def test_data(self) -> None:
"""Data URLs should not be modified."""
self.assertEqual(
rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"),
"data:,Hello%2C%20World%21",
)

View file

@ -54,13 +54,18 @@ from twisted.internet.interfaces import (
ITransport, ITransport,
) )
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.test.proto_helpers import (
AccumulatingProtocol,
MemoryReactor,
MemoryReactorClock,
)
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.resource import IResource from twisted.web.resource import IResource
from twisted.web.server import Request, Site from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig from synapse.config.database import DatabaseConnectionConfig
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import ContextResourceUsage
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.engines import PostgresEngine, create_engine
@ -88,18 +93,19 @@ class TimedOutException(Exception):
""" """
@attr.s @attr.s(auto_attribs=True)
class FakeChannel: class FakeChannel:
""" """
A fake Twisted Web Channel (the part that interfaces with the A fake Twisted Web Channel (the part that interfaces with the
wire). wire).
""" """
site = attr.ib(type=Union[Site, "FakeSite"]) site: Union[Site, "FakeSite"]
_reactor = attr.ib() _reactor: MemoryReactor
result = attr.ib(type=dict, default=attr.Factory(dict)) result: dict = attr.Factory(dict)
_ip = attr.ib(type=str, default="127.0.0.1") _ip: str = "127.0.0.1"
_producer: Optional[Union[IPullProducer, IPushProducer]] = None _producer: Optional[Union[IPullProducer, IPushProducer]] = None
resource_usage: Optional[ContextResourceUsage] = None
@property @property
def json_body(self): def json_body(self):
@ -168,6 +174,8 @@ class FakeChannel:
def requestDone(self, _self): def requestDone(self, _self):
self.result["done"] = True self.result["done"] = True
if isinstance(_self, SynapseRequest):
self.resource_usage = _self.logcontext.get_resource_usage()
def getPeer(self): def getPeer(self):
# We give an address so that getClientIP returns a non null entry, # We give an address so that getClientIP returns a non null entry,

View file

@ -47,9 +47,18 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
expected_ignorer_user_ids, expected_ignorer_user_ids,
) )
def assert_ignored(
self, ignorer_user_id: str, expected_ignored_user_ids: Set[str]
) -> None:
self.assertEqual(
self.get_success(self.store.ignored_users(ignorer_user_id)),
expected_ignored_user_ids,
)
def test_ignoring_users(self): def test_ignoring_users(self):
"""Basic adding/removing of users from the ignore list.""" """Basic adding/removing of users from the ignore list."""
self._update_ignore_list("@other:test", "@another:remote") self._update_ignore_list("@other:test", "@another:remote")
self.assert_ignored(self.user, {"@other:test", "@another:remote"})
# Check a user which no one ignores. # Check a user which no one ignores.
self.assert_ignorers("@user:test", set()) self.assert_ignorers("@user:test", set())
@ -62,6 +71,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
# Add one user, remove one user, and leave one user. # Add one user, remove one user, and leave one user.
self._update_ignore_list("@foo:test", "@another:remote") self._update_ignore_list("@foo:test", "@another:remote")
self.assert_ignored(self.user, {"@foo:test", "@another:remote"})
# Check the removed user. # Check the removed user.
self.assert_ignorers("@other:test", set()) self.assert_ignorers("@other:test", set())
@ -76,20 +86,24 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
"""Ensure that caching works properly between different users.""" """Ensure that caching works properly between different users."""
# The first user ignores a user. # The first user ignores a user.
self._update_ignore_list("@other:test") self._update_ignore_list("@other:test")
self.assert_ignored(self.user, {"@other:test"})
self.assert_ignorers("@other:test", {self.user}) self.assert_ignorers("@other:test", {self.user})
# The second user ignores them. # The second user ignores them.
self._update_ignore_list("@other:test", ignorer_user_id="@second:test") self._update_ignore_list("@other:test", ignorer_user_id="@second:test")
self.assert_ignored("@second:test", {"@other:test"})
self.assert_ignorers("@other:test", {self.user, "@second:test"}) self.assert_ignorers("@other:test", {self.user, "@second:test"})
# The first user un-ignores them. # The first user un-ignores them.
self._update_ignore_list() self._update_ignore_list()
self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", {"@second:test"}) self.assert_ignorers("@other:test", {"@second:test"})
def test_invalid_data(self): def test_invalid_data(self):
"""Invalid data ends up clearing out the ignored users list.""" """Invalid data ends up clearing out the ignored users list."""
# Add some data and ensure it is there. # Add some data and ensure it is there.
self._update_ignore_list("@other:test") self._update_ignore_list("@other:test")
self.assert_ignored(self.user, {"@other:test"})
self.assert_ignorers("@other:test", {self.user}) self.assert_ignorers("@other:test", {self.user})
# No ignored_users key. # No ignored_users key.
@ -102,10 +116,12 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
) )
# No one ignores the user now. # No one ignores the user now.
self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", set()) self.assert_ignorers("@other:test", set())
# Add some data and ensure it is there. # Add some data and ensure it is there.
self._update_ignore_list("@other:test") self._update_ignore_list("@other:test")
self.assert_ignored(self.user, {"@other:test"})
self.assert_ignorers("@other:test", {self.user}) self.assert_ignorers("@other:test", {self.user})
# Invalid data. # Invalid data.
@ -118,4 +134,5 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
) )
# No one ignores the user now. # No one ignores the user now.
self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", set()) self.assert_ignorers("@other:test", set())

View file

@ -17,8 +17,12 @@ from unittest.mock import Mock
import yaml import yaml
from twisted.internet.defer import Deferred, ensureDeferred from twisted.internet.defer import Deferred, ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.background_updates import BackgroundUpdater
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable, simple_async_mock from tests.test_utils import make_awaitable, simple_async_mock
@ -26,7 +30,7 @@ from tests.unittest import override_config
class BackgroundUpdateTestCase(unittest.HomeserverTestCase): class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
# the base test class should have run the real bg updates for us # the base test class should have run the real bg updates for us
self.assertTrue( self.assertTrue(
@ -39,7 +43,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
) )
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
async def update(self, progress, count): async def update(self, progress: JsonDict, count: int) -> int:
duration_ms = 10 duration_ms = 10
await self.clock.sleep((count * duration_ms) / 1000) await self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1} progress = {"my_key": progress["my_key"] + 1}
@ -51,7 +55,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
) )
return count return count
def test_do_background_update(self): def test_do_background_update(self) -> None:
# the time we claim it takes to update one item when running the update # the time we claim it takes to update one item when running the update
duration_ms = 10 duration_ms = 10
@ -80,7 +84,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# second step: complete the update # second step: complete the update
# we should now get run with a much bigger number of items to update # we should now get run with a much bigger number of items to update
async def update(progress, count): async def update(progress: JsonDict, count: int) -> int:
self.assertEqual(progress, {"my_key": 2}) self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual( self.assertAlmostEqual(
count, count,
@ -110,7 +114,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
""" """
) )
) )
def test_background_update_default_batch_set_by_config(self): def test_background_update_default_batch_set_by_config(self) -> None:
""" """
Test that the background update is run with the default_batch_size set by the config Test that the background update is run with the default_batch_size set by the config
""" """
@ -133,7 +137,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# on the first call, we should get run with the default background update size specified in the config # on the first call, we should get run with the default background update size specified in the config
self.update_handler.assert_called_once_with({"my_key": 1}, 20) self.update_handler.assert_called_once_with({"my_key": 1}, 20)
def test_background_update_default_sleep_behavior(self): def test_background_update_default_sleep_behavior(self) -> None:
""" """
Test default background update behavior, which is to sleep Test default background update behavior, which is to sleep
""" """
@ -147,7 +151,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = self.update self.update_handler.side_effect = self.update
self.update_handler.reset_mock() self.update_handler.reset_mock()
self.updates.start_doing_background_updates(), self.updates.start_doing_background_updates()
# 2: advance the reactor less than the default sleep duration (1000ms) # 2: advance the reactor less than the default sleep duration (1000ms)
self.reactor.pump([0.5]) self.reactor.pump([0.5])
@ -167,7 +171,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
""" """
) )
) )
def test_background_update_sleep_set_in_config(self): def test_background_update_sleep_set_in_config(self) -> None:
""" """
Test that changing the sleep time in the config changes how long it sleeps Test that changing the sleep time in the config changes how long it sleeps
""" """
@ -181,7 +185,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = self.update self.update_handler.side_effect = self.update
self.update_handler.reset_mock() self.update_handler.reset_mock()
self.updates.start_doing_background_updates(), self.updates.start_doing_background_updates()
# 2: advance the reactor less than the configured sleep duration (500ms) # 2: advance the reactor less than the configured sleep duration (500ms)
self.reactor.pump([0.45]) self.reactor.pump([0.45])
@ -201,7 +205,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
""" """
) )
) )
def test_disabling_background_update_sleep(self): def test_disabling_background_update_sleep(self) -> None:
""" """
Test that disabling sleep in the config results in bg update not sleeping Test that disabling sleep in the config results in bg update not sleeping
""" """
@ -215,7 +219,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = self.update self.update_handler.side_effect = self.update
self.update_handler.reset_mock() self.update_handler.reset_mock()
self.updates.start_doing_background_updates(), self.updates.start_doing_background_updates()
# 2: advance the reactor very little # 2: advance the reactor very little
self.reactor.pump([0.025]) self.reactor.pump([0.025])
@ -230,7 +234,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
""" """
) )
) )
def test_background_update_duration_set_in_config(self): def test_background_update_duration_set_in_config(self) -> None:
""" """
Test that the desired duration set in the config is used in determining batch size Test that the desired duration set in the config is used in determining batch size
""" """
@ -254,7 +258,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# the first update was run with the default batch size, this should be run with 500ms as the # the first update was run with the default batch size, this should be run with 500ms as the
# desired duration # desired duration
async def update(progress, count): async def update(progress: JsonDict, count: int) -> int:
self.assertEqual(progress, {"my_key": 2}) self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual( self.assertAlmostEqual(
count, count,
@ -275,7 +279,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
""" """
) )
) )
def test_background_update_min_batch_set_in_config(self): def test_background_update_min_batch_set_in_config(self) -> None:
""" """
Test that the minimum batch size set in the config is used Test that the minimum batch size set in the config is used
""" """
@ -290,7 +294,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
) )
# Run the update with the long-running update item # Run the update with the long-running update item
async def update(progress, count): async def update_long(progress: JsonDict, count: int) -> int:
await self.clock.sleep((count * duration_ms) / 1000) await self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1} progress = {"my_key": progress["my_key"] + 1}
await self.store.db_pool.runInteraction( await self.store.db_pool.runInteraction(
@ -301,7 +305,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
) )
return count return count
self.update_handler.side_effect = update self.update_handler.side_effect = update_long
self.update_handler.reset_mock() self.update_handler.reset_mock()
res = self.get_success( res = self.get_success(
self.updates.do_next_background_update(False), self.updates.do_next_background_update(False),
@ -311,25 +315,25 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# the first update was run with the default batch size, this should be run with minimum batch size # the first update was run with the default batch size, this should be run with minimum batch size
# as the first items took a very long time # as the first items took a very long time
async def update(progress, count): async def update_short(progress: JsonDict, count: int) -> int:
self.assertEqual(progress, {"my_key": 2}) self.assertEqual(progress, {"my_key": 2})
self.assertEqual(count, 5) self.assertEqual(count, 5)
await self.updates._end_background_update("test_update") await self.updates._end_background_update("test_update")
return count return count
self.update_handler.side_effect = update self.update_handler.side_effect = update_short
self.get_success(self.updates.do_next_background_update(False)) self.get_success(self.updates.do_next_background_update(False))
class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
# the base test class should have run the real bg updates for us # the base test class should have run the real bg updates for us
self.assertTrue( self.assertTrue(
self.get_success(self.updates.has_completed_background_updates()) self.get_success(self.updates.has_completed_background_updates())
) )
self.update_deferred = Deferred() self.update_deferred: Deferred[int] = Deferred()
self.update_handler = Mock(return_value=self.update_deferred) self.update_handler = Mock(return_value=self.update_deferred)
self.updates.register_background_update_handler( self.updates.register_background_update_handler(
"test_update", self.update_handler "test_update", self.update_handler
@ -358,7 +362,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
), ),
) )
def test_controller(self): def test_controller(self) -> None:
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
self.get_success( self.get_success(
store.db_pool.simple_insert( store.db_pool.simple_insert(
@ -368,7 +372,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
) )
# Set the return value for the context manager. # Set the return value for the context manager.
enter_defer = Deferred() enter_defer: Deferred[int] = Deferred()
self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer) self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer)
# Start the background update. # Start the background update.

View file

@ -12,7 +12,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.database import make_tuple_comparison_clause from typing import Callable, Tuple
from unittest.mock import Mock, call
from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.database import (
DatabasePool,
LoggingTransaction,
make_tuple_comparison_clause,
)
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -22,3 +35,150 @@ class TupleComparisonClauseTestCase(unittest.TestCase):
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)]) clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
self.assertEqual(clause, "(a,b) > (?,?)") self.assertEqual(clause, "(a,b) > (?,?)")
self.assertEqual(args, [1, 2]) self.assertEqual(args, [1, 2])
class CallbacksTestCase(unittest.HomeserverTestCase):
"""Tests for transaction callbacks."""
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
def _run_interaction(
self, func: Callable[[LoggingTransaction], object]
) -> Tuple[Mock, Mock]:
"""Run the given function in a database transaction, with callbacks registered.
Args:
func: The function to be run in a transaction. The transaction will be
retried if `func` raises an `OperationalError`.
Returns:
Two mocks, which were registered as an `after_callback` and an
`exception_callback` respectively, on every transaction attempt.
"""
after_callback = Mock()
exception_callback = Mock()
def _test_txn(txn: LoggingTransaction) -> None:
txn.call_after(after_callback, 123, 456, extra=789)
txn.call_on_exception(exception_callback, 987, 654, extra=321)
func(txn)
try:
self.get_success_or_raise(
self.db_pool.runInteraction("test_transaction", _test_txn)
)
except Exception:
pass
return after_callback, exception_callback
def test_after_callback(self) -> None:
"""Test that the after callback is called when a transaction succeeds."""
after_callback, exception_callback = self._run_interaction(lambda txn: None)
after_callback.assert_called_once_with(123, 456, extra=789)
exception_callback.assert_not_called()
def test_exception_callback(self) -> None:
"""Test that the exception callback is called when a transaction fails."""
_test_txn = Mock(side_effect=ZeroDivisionError)
after_callback, exception_callback = self._run_interaction(_test_txn)
after_callback.assert_not_called()
exception_callback.assert_called_once_with(987, 654, extra=321)
def test_failed_retry(self) -> None:
"""Test that the exception callback is called for every failed attempt."""
# Always raise an `OperationalError`.
_test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError)
after_callback, exception_callback = self._run_interaction(_test_txn)
after_callback.assert_not_called()
exception_callback.assert_has_calls(
[
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
]
)
self.assertEqual(exception_callback.call_count, 6) # no additional calls
def test_successful_retry(self) -> None:
"""Test callbacks for a failed transaction followed by a successful attempt."""
# Raise an `OperationalError` on the first attempt only.
_test_txn = Mock(
side_effect=[self.db_pool.engine.module.OperationalError, None]
)
after_callback, exception_callback = self._run_interaction(_test_txn)
# Calling both `after_callback`s when the first attempt failed is rather
# surprising (#12184). Let's document the behaviour in a test.
after_callback.assert_has_calls(
[
call(123, 456, extra=789),
call(123, 456, extra=789),
]
)
self.assertEqual(after_callback.call_count, 2) # no additional calls
exception_callback.assert_not_called()
class CancellationTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
def test_after_callback(self) -> None:
"""Test that the after callback is called when a transaction succeeds."""
d: "Deferred[None]"
after_callback = Mock()
exception_callback = Mock()
def _test_txn(txn: LoggingTransaction) -> None:
txn.call_after(after_callback, 123, 456, extra=789)
txn.call_on_exception(exception_callback, 987, 654, extra=321)
d.cancel()
d = defer.ensureDeferred(
self.db_pool.runInteraction("test_transaction", _test_txn)
)
self.get_failure(d, CancelledError)
after_callback.assert_called_once_with(123, 456, extra=789)
exception_callback.assert_not_called()
def test_exception_callback(self) -> None:
"""Test that the exception callback is called when a transaction fails."""
d: "Deferred[None]"
after_callback = Mock()
exception_callback = Mock()
def _test_txn(txn: LoggingTransaction) -> None:
txn.call_after(after_callback, 123, 456, extra=789)
txn.call_on_exception(exception_callback, 987, 654, extra=321)
d.cancel()
# Simulate a retryable failure on every attempt.
raise self.db_pool.engine.module.OperationalError()
d = defer.ensureDeferred(
self.db_pool.runInteraction("test_transaction", _test_txn)
)
self.get_failure(d, CancelledError)
after_callback.assert_not_called()
exception_callback.assert_has_calls(
[
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
]
)
self.assertEqual(exception_callback.call_count, 6) # no additional calls

View file

@ -13,9 +13,13 @@
# limitations under the License. # limitations under the License.
from typing import List, Optional from typing import List, Optional
from synapse.storage.database import DatabasePool from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS from tests.utils import USE_POSTGRES_FOR_TESTS
@ -25,13 +29,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS: if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres" skip = "Requires Postgres"
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn): def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq") txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute( txn.execute(
""" """
@ -59,12 +63,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
def _insert_rows(self, instance_name: str, number: int): def _insert_rows(self, instance_name: str, number: int) -> None:
"""Insert N rows as the given instance, inserting with stream IDs pulled """Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence. from the postgres sequence.
""" """
def _insert(txn): def _insert(txn: LoggingTransaction) -> None:
for _ in range(number): for _ in range(number):
txn.execute( txn.execute(
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
@ -80,12 +84,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
def _insert_row_with_id(self, instance_name: str, stream_id: int): def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
"""Insert one row as the given instance with given stream_id, updating """Insert one row as the given instance with given stream_id, updating
the postgres sequence position to match. the postgres sequence position to match.
""" """
def _insert(txn): def _insert(txn: LoggingTransaction) -> None:
txn.execute( txn.execute(
"INSERT INTO foobar VALUES (?, ?)", "INSERT INTO foobar VALUES (?, ?)",
( (
@ -104,7 +108,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
def test_empty(self): def test_empty(self) -> None:
"""Test an ID generator against an empty database gives sensible """Test an ID generator against an empty database gives sensible
current positions. current positions.
""" """
@ -114,7 +118,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# The table is empty so we expect an empty map for positions # The table is empty so we expect an empty map for positions
self.assertEqual(id_gen.get_positions(), {}) self.assertEqual(id_gen.get_positions(), {})
def test_single_instance(self): def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled """Test that reads and writes from a single process are handled
correctly. correctly.
""" """
@ -130,7 +134,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Try allocating a new ID gen and check that we only see position # Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager. # advanced after we leave the context manager.
async def _get_next_async(): async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id: async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8) self.assertEqual(stream_id, 8)
@ -142,7 +146,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_out_of_order_finish(self): def test_out_of_order_finish(self) -> None:
"""Test that IDs persisted out of order are correctly handled""" """Test that IDs persisted out of order are correctly handled"""
# Prefill table with 7 rows written by 'master' # Prefill table with 7 rows written by 'master'
@ -191,7 +195,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 11}) self.assertEqual(id_gen.get_positions(), {"master": 11})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
def test_multi_instance(self): def test_multi_instance(self) -> None:
"""Test that reads and writes from multiple processes are handled """Test that reads and writes from multiple processes are handled
correctly. correctly.
""" """
@ -215,7 +219,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Try allocating a new ID gen and check that we only see position # Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager. # advanced after we leave the context manager.
async def _get_next_async(): async def _get_next_async() -> None:
async with first_id_gen.get_next() as stream_id: async with first_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8) self.assertEqual(stream_id, 8)
@ -233,7 +237,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# ... but calling `get_next` on the second instance should give a unique # ... but calling `get_next` on the second instance should give a unique
# stream ID # stream ID
async def _get_next_async(): async def _get_next_async2() -> None:
async with second_id_gen.get_next() as stream_id: async with second_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 9) self.assertEqual(stream_id, 9)
@ -241,7 +245,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen.get_positions(), {"first": 3, "second": 7} second_id_gen.get_positions(), {"first": 3, "second": 7}
) )
self.get_success(_get_next_async()) self.get_success(_get_next_async2())
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9}) self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
@ -249,7 +253,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen.advance("first", 8) second_id_gen.advance("first", 8)
self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
def test_get_next_txn(self): def test_get_next_txn(self) -> None:
"""Test that the `get_next_txn` function works correctly.""" """Test that the `get_next_txn` function works correctly."""
# Prefill table with 7 rows written by 'master' # Prefill table with 7 rows written by 'master'
@ -263,7 +267,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Try allocating a new ID gen and check that we only see position # Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager. # advanced after we leave the context manager.
def _get_next_txn(txn): def _get_next_txn(txn: LoggingTransaction) -> None:
stream_id = id_gen.get_next_txn(txn) stream_id = id_gen.get_next_txn(txn)
self.assertEqual(stream_id, 8) self.assertEqual(stream_id, 8)
@ -275,7 +279,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_get_persisted_upto_position(self): def test_get_persisted_upto_position(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to """Test that `get_persisted_upto_position` correctly tracks updates to
positions. positions.
""" """
@ -317,7 +321,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen.advance("second", 15) id_gen.advance("second", 15)
self.assertEqual(id_gen.get_persisted_upto_position(), 11) self.assertEqual(id_gen.get_persisted_upto_position(), 11)
def test_get_persisted_upto_position_get_next(self): def test_get_persisted_upto_position_get_next(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to """Test that `get_persisted_upto_position` correctly tracks updates to
positions when `get_next` is called. positions when `get_next` is called.
""" """
@ -331,7 +335,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_persisted_upto_position(), 5) self.assertEqual(id_gen.get_persisted_upto_position(), 5)
async def _get_next_async(): async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id: async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 6) self.assertEqual(stream_id, 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 5) self.assertEqual(id_gen.get_persisted_upto_position(), 5)
@ -344,7 +348,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# `persisted_upto_position` in this case, then it will be correct in the # `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code). # other cases that are tested above (since they'll hit the same code).
def test_restart_during_out_of_order_persistence(self): def test_restart_during_out_of_order_persistence(self) -> None:
"""Test that restarting a process while another process is writing out """Test that restarting a process while another process is writing out
of order updates are handled correctly. of order updates are handled correctly.
""" """
@ -388,7 +392,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen_worker.advance("master", 9) id_gen_worker.advance("master", 9)
self.assertEqual(id_gen_worker.get_positions(), {"master": 9}) self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
def test_writer_config_change(self): def test_writer_config_change(self) -> None:
"""Test that changing the writer config correctly works.""" """Test that changing the writer config correctly works."""
self._insert_row_with_id("first", 3) self._insert_row_with_id("first", 3)
@ -421,7 +425,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Check that we get a sane next stream ID with this new config. # Check that we get a sane next stream ID with this new config.
async def _get_next_async(): async def _get_next_async() -> None:
async with id_gen_3.get_next() as stream_id: async with id_gen_3.get_next() as stream_id:
self.assertEqual(stream_id, 6) self.assertEqual(stream_id, 6)
@ -435,7 +439,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6) self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6)
self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6) self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
def test_sequence_consistency(self): def test_sequence_consistency(self) -> None:
"""Test that we error out if the table and sequence diverges.""" """Test that we error out if the table and sequence diverges."""
# Prefill with some rows # Prefill with some rows
@ -458,13 +462,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS: if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres" skip = "Requires Postgres"
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn): def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq") txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute( txn.execute(
""" """
@ -493,10 +497,10 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success(self.db_pool.runWithConnection(_create)) return self.get_success(self.db_pool.runWithConnection(_create))
def _insert_row(self, instance_name: str, stream_id: int): def _insert_row(self, instance_name: str, stream_id: int) -> None:
"""Insert one row as the given instance with given stream_id.""" """Insert one row as the given instance with given stream_id."""
def _insert(txn): def _insert(txn: LoggingTransaction) -> None:
txn.execute( txn.execute(
"INSERT INTO foobar VALUES (?, ?)", "INSERT INTO foobar VALUES (?, ?)",
( (
@ -514,13 +518,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
def test_single_instance(self): def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled """Test that reads and writes from a single process are handled
correctly. correctly.
""" """
id_gen = self._create_id_generator() id_gen = self._create_id_generator()
async def _get_next_async(): async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id: async with id_gen.get_next() as stream_id:
self._insert_row("master", stream_id) self._insert_row("master", stream_id)
@ -530,7 +534,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1) self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
self.assertEqual(id_gen.get_persisted_upto_position(), -1) self.assertEqual(id_gen.get_persisted_upto_position(), -1)
async def _get_next_async2(): async def _get_next_async2() -> None:
async with id_gen.get_next_mult(3) as stream_ids: async with id_gen.get_next_mult(3) as stream_ids:
for stream_id in stream_ids: for stream_id in stream_ids:
self._insert_row("master", stream_id) self._insert_row("master", stream_id)
@ -548,14 +552,14 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4) self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
self.assertEqual(second_id_gen.get_persisted_upto_position(), -4) self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
def test_multiple_instance(self): def test_multiple_instance(self) -> None:
"""Tests that having multiple instances that get advanced over """Tests that having multiple instances that get advanced over
federation works corretly. federation works corretly.
""" """
id_gen_1 = self._create_id_generator("first", writers=["first", "second"]) id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
id_gen_2 = self._create_id_generator("second", writers=["first", "second"]) id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
async def _get_next_async(): async def _get_next_async() -> None:
async with id_gen_1.get_next() as stream_id: async with id_gen_1.get_next() as stream_id:
self._insert_row("first", stream_id) self._insert_row("first", stream_id)
id_gen_2.advance("first", stream_id) id_gen_2.advance("first", stream_id)
@ -567,7 +571,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
async def _get_next_async2(): async def _get_next_async2() -> None:
async with id_gen_2.get_next() as stream_id: async with id_gen_2.get_next() as stream_id:
self._insert_row("second", stream_id) self._insert_row("second", stream_id)
id_gen_1.advance("second", stream_id) id_gen_1.advance("second", stream_id)
@ -584,13 +588,13 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS: if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres" skip = "Requires Postgres"
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn): def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq") txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute( txn.execute(
""" """
@ -642,7 +646,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
from the postgres sequence. from the postgres sequence.
""" """
def _insert(txn): def _insert(txn: LoggingTransaction) -> None:
for _ in range(number): for _ in range(number):
txn.execute( txn.execute(
"INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,), "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
@ -659,7 +663,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
def test_load_existing_stream(self): def test_load_existing_stream(self) -> None:
"""Test creating ID gens with multiple tables that have rows from after """Test creating ID gens with multiple tables that have rows from after
the position in `stream_positions` table. the position in `stream_positions` table.
""" """

View file

@ -0,0 +1,46 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock, patch
from synapse.storage.database import make_conn
from synapse.storage.engines._base import IncorrectDatabaseSetup
from tests.unittest import HomeserverTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS
class UnsafeLocaleTest(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
@patch("synapse.storage.engines.postgres.PostgresEngine.get_db_locale")
def test_unsafe_locale(self, mock_db_locale: MagicMock) -> None:
mock_db_locale.return_value = ("B", "B")
database = self.hs.get_datastores().databases[0]
db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
with self.assertRaises(IncorrectDatabaseSetup):
database.engine.check_database(db_conn)
with self.assertRaises(IncorrectDatabaseSetup):
database.engine.check_new_database(db_conn)
db_conn.close()
def test_safe_locale(self) -> None:
database = self.hs.get_datastores().databases[0]
db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
with db_conn.cursor() as txn:
res = database.engine.get_db_locale(txn)
self.assertEqual(res, ("C", "C"))
db_conn.close()

View file

@ -12,7 +12,7 @@ from tests.unittest import TestCase
class DummyDistribution(metadata.Distribution): class DummyDistribution(metadata.Distribution):
def __init__(self, version: str): def __init__(self, version: object):
self._version = version self._version = version
@property @property
@ -30,6 +30,7 @@ old = DummyDistribution("0.1.2")
old_release_candidate = DummyDistribution("0.1.2rc3") old_release_candidate = DummyDistribution("0.1.2rc3")
new = DummyDistribution("1.2.3") new = DummyDistribution("1.2.3")
new_release_candidate = DummyDistribution("1.2.3rc4") new_release_candidate = DummyDistribution("1.2.3rc4")
distribution_with_no_version = DummyDistribution(None)
# could probably use stdlib TestCase --- no need for twisted here # could probably use stdlib TestCase --- no need for twisted here
@ -67,6 +68,18 @@ class TestDependencyChecker(TestCase):
# should not raise # should not raise
check_requirements() check_requirements()
def test_version_reported_as_none(self) -> None:
"""Complain if importlib.metadata.version() returns None.
This shouldn't normally happen, but it was seen in the wild (#12223).
"""
with patch(
"synapse.util.check_dependencies.metadata.requires",
return_value=["dummypkg >= 1"],
):
with self.mock_installed_package(distribution_with_no_version):
self.assertRaises(DependencyException, check_requirements)
def test_checks_ignore_dev_dependencies(self) -> None: def test_checks_ignore_dev_dependencies(self) -> None:
"""Bot generic and per-extra checks should ignore dev dependencies.""" """Bot generic and per-extra checks should ignore dev dependencies."""
with patch( with patch(

View file

@ -15,13 +15,8 @@
import atexit import atexit
import os import os
from unittest.mock import Mock, patch
from urllib import parse as urlparse
from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.config.server import DEFAULT_ROOM_VERSION
@ -187,111 +182,6 @@ def mock_getRawHeaders(headers=None):
return getRawHeaders return getRawHeaders
# This is a mock /resource/ not an entire server
class MockHttpResource:
def __init__(self, prefix=""):
self.callbacks = [] # 3-tuple of method/pattern/function
self.prefix = prefix
def trigger_get(self, path):
return self.trigger(b"GET", path, None)
@patch("twisted.web.http.Request")
@defer.inlineCallbacks
def trigger(
self, http_method, path, content, mock_request, federation_auth_origin=None
):
"""Fire an HTTP event.
Args:
http_method : The HTTP method
path : The HTTP path
content : The HTTP body
mock_request : Mocked request to pass to the event so it can get
content.
federation_auth_origin (bytes|None): domain to authenticate as, for federation
Returns:
A tuple of (code, response)
Raises:
KeyError If no event is found which will handle the path.
"""
path = self.prefix + path
# annoyingly we return a twisted http request which has chained calls
# to get at the http content, hence mock it here.
mock_content = Mock()
config = {"read.return_value": content}
mock_content.configure_mock(**config)
mock_request.content = mock_content
mock_request.method = http_method.encode("ascii")
mock_request.uri = path.encode("ascii")
mock_request.getClientIP.return_value = "-"
headers = {}
if federation_auth_origin is not None:
headers[b"Authorization"] = [
b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
]
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
# return the right path if the event requires it
mock_request.path = path
# add in query params to the right place
try:
mock_request.args = urlparse.parse_qs(path.split("?")[1])
mock_request.path = path.split("?")[0]
path = mock_request.path
except Exception:
pass
if isinstance(path, bytes):
path = path.decode("utf8")
for (method, pattern, func) in self.callbacks:
if http_method != method:
continue
matcher = pattern.match(path)
if matcher:
try:
args = [urlparse.unquote(u) for u in matcher.groups()]
(code, response) = yield defer.ensureDeferred(
func(mock_request, *args)
)
return code, response
except CodeMessageException as e:
return e.code, cs_error(e.msg, code=e.errcode)
raise KeyError("No event can handle %s" % path)
def register_paths(self, method, path_patterns, callback, servlet_name):
for path_pattern in path_patterns:
self.callbacks.append((method, path_pattern, callback))
class MockKey:
alg = "mock_alg"
version = "mock_version"
signature = b"\x9a\x87$"
@property
def verify_key(self):
return self
def sign(self, message):
return self
def verify(self, message, sig):
assert sig == b"\x9a\x87$"
def encode(self):
return b"<fake_encoded_key>"
class MockClock: class MockClock:
now = 1000 now = 1000